Skip to content

Commit a0047f5

Browse files
authored
refine: refine smux conn #40 (#41)
* refine: refine smux conn * add cancel in copybuffer * use round_robin as lb algorithm #31 #40
1 parent 6ee3a16 commit a0047f5

File tree

14 files changed

+143
-278
lines changed

14 files changed

+143
-278
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ go get -u "github.com/Ehco1996/ehco/cmd/ehco"
8686
}
8787
```
8888

89-
## Benchmark
89+
## Benchmark(Apple m1)
9090

9191
iperf:
9292

@@ -140,5 +140,5 @@ iperf3 -c 0.0.0.0 -p 1234 -u -b 1G --length 1024
140140

141141
| iperf | raw | relay(raw) | relay(ws) |relay(wss) | relay(mwss)|
142142
| ---- | ---- | ---- | ---- | ---- | ---- |
143-
| tcp | 62.6 Gbits/sec | 23.9 Gbits/sec | 14.65 Gbits/sec | 4.22 Gbits/sec | 2.43 Gbits/sec |
143+
| tcp | 123 Gbits/sec | 55 Gbits/sec | 41 Gbits/sec | 10 Gbits/sec | 5.78 Gbits/sec |
144144
| udp | 14.5 Gbits/sec | 3.3 Gbits/sec | 直接转发 | 直接转发 | 直接转发 |

internal/lb/lb.go

Lines changed: 0 additions & 98 deletions
This file was deleted.

internal/lb/lb_test.go

Lines changed: 0 additions & 39 deletions
This file was deleted.

internal/lb/round_robin.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package lb
2+
3+
import (
4+
"sync/atomic"
5+
)
6+
7+
// RoundRobin is an interface for representing round-robin balancing.
8+
type RoundRobin interface {
9+
Next() string
10+
}
11+
12+
type roundrobin struct {
13+
remotes []string
14+
next uint32
15+
}
16+
17+
func NewRBRemotes(remotes []string) RoundRobin {
18+
return &roundrobin{remotes: remotes}
19+
}
20+
21+
func (r *roundrobin) Next() string {
22+
n := atomic.AddUint32(&r.next, 1)
23+
return r.remotes[(int(n)-1)%len(r.remotes)]
24+
}

internal/lb/round_robin_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package lb
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func Test_roundrobin_Next(t *testing.T) {
8+
remotes := []string{
9+
"127.0.0.1",
10+
"127.0.0.2",
11+
"127.0.0.3",
12+
}
13+
rb := NewRBRemotes(remotes)
14+
for i := 0; i < len(remotes); i++ {
15+
if res := rb.Next(); res != remotes[i] {
16+
t.Fatalf("need %s got %s", remotes[i], res)
17+
}
18+
}
19+
}

internal/relay/relay.go

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ func NewRelay(cfg *config.RelayConfig) (*Relay, error) {
4747

4848
TP: transporter.PickTransporter(
4949
cfg.TransportType,
50-
lb.New(cfg.TCPRemotes),
51-
lb.New(cfg.UDPRemotes),
50+
lb.NewRBRemotes(cfg.TCPRemotes),
51+
lb.NewRBRemotes(cfg.UDPRemotes),
5252
),
5353
}
5454

@@ -179,35 +179,33 @@ func (r *Relay) RunLocalWSSServer() error {
179179
func (r *Relay) RunLocalMWSSServer() error {
180180
r.LogRelay()
181181
tp := r.TP.(*transporter.Raw)
182-
s := &transporter.MWSSServer{
183-
ConnChan: make(chan net.Conn, 1024),
184-
ErrChan: make(chan error, 1),
185-
}
182+
mwssServer := transporter.NewMWSSServer()
186183
mux := mux.NewRouter()
187184
mux.Handle("/", http.HandlerFunc(web.Index))
188-
mux.Handle("/mwss/", http.HandlerFunc(s.Upgrade))
189-
server := &http.Server{
185+
mux.Handle("/mwss/", http.HandlerFunc(mwssServer.Upgrade))
186+
httpServer := &http.Server{
190187
Addr: r.LocalTCPAddr.String(),
191188
Handler: mux,
192189
TLSConfig: mytls.DefaultTLSConfig,
193190
ReadHeaderTimeout: 30 * time.Second,
194191
}
195-
s.Server = server
192+
mwssServer.Server = httpServer
196193

197194
ln, err := net.Listen("tcp", r.LocalTCPAddr.String())
198195
if err != nil {
199196
return err
200197
}
201198
go func() {
202-
err := server.Serve(tls.NewListener(ln, server.TLSConfig))
199+
err := httpServer.Serve(tls.NewListener(ln, httpServer.TLSConfig))
203200
if err != nil {
204-
s.ErrChan <- err
201+
mwssServer.ErrChan <- err
205202
}
206-
close(s.ErrChan)
203+
close(mwssServer.ErrChan)
207204
}()
205+
208206
var tempDelay time.Duration
209207
for {
210-
conn, e := s.Accept()
208+
conn, e := mwssServer.Accept()
211209
if e != nil {
212210
if ne, ok := e.(net.Error); ok && ne.Temporary() {
213211
if tempDelay == 0 {

internal/transporter/buffer.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package transporter
22

33
import (
4+
"context"
45
"errors"
56
"io"
67
"net"
@@ -72,15 +73,28 @@ func copyBuffer(dst io.Writer, src io.Reader, bufferPool *sync.Pool) (written in
7273

7374
// NOTE must call setdeadline before use this func or may goroutine leak
7475
func transport(rw1, rw2 io.ReadWriter) error {
76+
ctx, cancel := context.WithCancel(context.Background())
77+
defer cancel()
78+
7579
errc := make(chan error, 1)
7680
go func() {
77-
_, err := copyBuffer(rw1, rw2, InboundBufferPool)
78-
errc <- err
81+
select {
82+
case <-ctx.Done():
83+
println("ctx done exits copy1")
84+
default:
85+
_, err := copyBuffer(rw1, rw2, InboundBufferPool)
86+
errc <- err
87+
}
7988
}()
8089

8190
go func() {
82-
_, err := copyBuffer(rw2, rw1, OutboundBufferPool)
83-
errc <- err
91+
select {
92+
case <-ctx.Done():
93+
println("ctx done exit copy1")
94+
default:
95+
_, err := copyBuffer(rw2, rw1, InboundBufferPool)
96+
errc <- err
97+
}
8498
}()
8599
err := <-errc
86100
// NOTE 我们不关心operror 比如 eof/reset/broken pipe

0 commit comments

Comments
 (0)