Skip to content

Commit ed89915

Browse files
committed
Add support for Multipath TCP
1 parent fec96a3 commit ed89915

File tree

9 files changed

+102
-42
lines changed

9 files changed

+102
-42
lines changed

src/admin/getpeers.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type PeerEntry struct {
2424
PublicKey string `json:"key"`
2525
Port uint64 `json:"port"`
2626
Priority uint64 `json:"priority"`
27+
Multipath bool `json:"multipath,omitempty"`
2728
RXBytes DataUnit `json:"bytes_recvd,omitempty"`
2829
TXBytes DataUnit `json:"bytes_sent,omitempty"`
2930
Uptime float64 `json:"uptime,omitempty"`
@@ -37,14 +38,15 @@ func (a *AdminSocket) getPeersHandler(req *GetPeersRequest, res *GetPeersRespons
3738
res.Peers = make([]PeerEntry, 0, len(peers))
3839
for _, p := range peers {
3940
peer := PeerEntry{
40-
Port: p.Port,
41-
Up: p.Up,
42-
Inbound: p.Inbound,
43-
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
44-
URI: p.URI,
45-
RXBytes: DataUnit(p.RXBytes),
46-
TXBytes: DataUnit(p.TXBytes),
47-
Uptime: p.Uptime.Seconds(),
41+
Port: p.Port,
42+
Up: p.Up,
43+
Inbound: p.Inbound,
44+
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
45+
Multipath: p.Multipath,
46+
URI: p.URI,
47+
RXBytes: DataUnit(p.RXBytes),
48+
TXBytes: DataUnit(p.TXBytes),
49+
Uptime: p.Uptime.Seconds(),
4850
}
4951
if p.Latency > 0 {
5052
peer.Latency = p.Latency

src/core/api.go

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type PeerInfo struct {
3030
Coords []uint64
3131
Port uint64
3232
Priority uint8
33+
Multipath bool
3334
RXBytes uint64
3435
TXBytes uint64
3536
Uptime time.Duration
@@ -87,6 +88,7 @@ func (c *Core) GetPeers() []PeerInfo {
8788
peerinfo.RXBytes = atomic.LoadUint64(&c.rx)
8889
peerinfo.TXBytes = atomic.LoadUint64(&c.tx)
8990
peerinfo.Uptime = time.Since(c.up)
91+
peerinfo.Multipath = isMPTCP(c)
9092
}
9193
if p, ok := conns[conn]; ok {
9294
peerinfo.Key = p.Key

src/core/link.go

+40-13
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type links struct {
4444

4545
type linkProtocol interface {
4646
dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error)
47-
listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error)
47+
listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error)
4848
}
4949

5050
// linkInfo is used as a map key
@@ -72,6 +72,7 @@ type linkOptions struct {
7272
tlsSNI string
7373
password []byte
7474
maxBackoff time.Duration
75+
multipath bool
7576
}
7677

7778
type Listener struct {
@@ -140,6 +141,7 @@ const ErrLinkPinnedKeyInvalid = linkError("pinned public key is invalid")
140141
const ErrLinkPasswordInvalid = linkError("password is invalid")
141142
const ErrLinkUnrecognisedSchema = linkError("link schema unknown")
142143
const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid")
144+
const ErrLinkMultipathInvalid = linkError("multipath invalid")
143145

144146
func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
145147
var retErr error
@@ -193,6 +195,17 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
193195
}
194196
options.maxBackoff = d
195197
}
198+
if p := u.Query().Get("multipath"); p != "" {
199+
switch p {
200+
case "true", "1":
201+
options.multipath = true
202+
case "false", "0":
203+
options.multipath = false
204+
default:
205+
retErr = ErrLinkMultipathInvalid
206+
return
207+
}
208+
}
196209
// SNI headers must contain hostnames and not IP addresses, so we must make sure
197210
// that we do not populate the SNI with an IP literal. We do this by splitting
198211
// the host-port combo from the query option and then seeing if it parses to an
@@ -379,7 +392,7 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
379392
return retErr
380393
}
381394

382-
func (l *links) remove(u *url.URL, sintf string, linkType linkType) error {
395+
func (l *links) remove(u *url.URL, sintf string, _ linkType) error {
383396
var retErr error
384397
phony.Block(l, func() {
385398
// Generate the link info and see whether we think we already
@@ -422,31 +435,45 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) {
422435
cancel()
423436
return nil, ErrLinkUnrecognisedSchema
424437
}
425-
listener, err := protocol.listen(ctx, u, sintf)
426-
if err != nil {
427-
cancel()
428-
return nil, err
429-
}
430-
li := &Listener{
431-
listener: listener,
432-
ctx: ctx,
433-
Cancel: cancel,
434-
}
435438

436439
var options linkOptions
437440
if p := u.Query().Get("priority"); p != "" {
438441
pi, err := strconv.ParseUint(p, 10, 8)
439442
if err != nil {
443+
cancel()
440444
return nil, ErrLinkPriorityInvalid
441445
}
442446
options.priority = uint8(pi)
443447
}
444448
if p := u.Query().Get("password"); p != "" {
445449
if len(p) > blake2b.Size {
450+
cancel()
446451
return nil, ErrLinkPasswordInvalid
447452
}
448453
options.password = []byte(p)
449454
}
455+
if p := u.Query().Get("multipath"); p != "" {
456+
switch p {
457+
case "true", "1":
458+
options.multipath = true
459+
case "false", "0":
460+
options.multipath = false
461+
default:
462+
cancel()
463+
return nil, ErrLinkMultipathInvalid
464+
}
465+
}
466+
467+
listener, err := protocol.listen(ctx, u, sintf, options)
468+
if err != nil {
469+
cancel()
470+
return nil, err
471+
}
472+
li := &Listener{
473+
listener: listener,
474+
ctx: ctx,
475+
Cancel: cancel,
476+
}
450477

451478
go func() {
452479
l.core.log.Infof("%s listener started on %s", strings.ToUpper(u.Scheme), listener.Addr())
@@ -567,7 +594,7 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s
567594
switch {
568595
case err != nil:
569596
return fmt.Errorf("write handshake: %w", err)
570-
case err == nil && n != len(metaBytes):
597+
case n != len(metaBytes):
571598
return fmt.Errorf("incomplete handshake send")
572599
}
573600
meta = version_metadata{}

src/core/link_quic.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, option
6565
}, nil
6666
}
6767

68-
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
68+
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
6969
ql, err := quic.ListenAddr(url.Host, l.tlsconfig, l.quicconfig)
7070
if err != nil {
7171
return nil, err

src/core/link_socks.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,6 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options
4747
return conn, nil
4848
}
4949

50-
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
50+
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
5151
return nil, fmt.Errorf("SOCKS listener not supported")
5252
}

src/core/link_tcp.go

+13-6
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ type tcpDialer struct {
3636
addr *net.TCPAddr
3737
}
3838

39-
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) {
39+
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo, options linkOptions) ([]*tcpDialer, error) {
4040
host, p, err := net.SplitHostPort(url.Host)
4141
if err != nil {
4242
return nil, err
@@ -55,7 +55,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
5555
IP: ip,
5656
Port: port,
5757
}
58-
dialer, err := l.dialerFor(addr, info.sintf)
58+
dialer, err := l.dialerFor(addr, info.sintf, options.multipath)
5959
if err != nil {
6060
continue
6161
}
@@ -69,7 +69,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
6969
}
7070

7171
func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
72-
dialers, err := l.dialersFor(url, info)
72+
dialers, err := l.dialersFor(url, info, options)
7373
if err != nil {
7474
return nil, err
7575
}
@@ -88,17 +88,21 @@ func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options
8888
return nil, err
8989
}
9090

91-
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {
91+
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
9292
hostport := url.Host
9393
if sintf != "" {
9494
if host, port, err := net.SplitHostPort(hostport); err == nil {
9595
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
9696
}
9797
}
98-
return l.listenconfig.Listen(ctx, "tcp", hostport)
98+
lc := *l.listenconfig
99+
if options.multipath {
100+
setMPTCPForListener(&lc)
101+
}
102+
return lc.Listen(ctx, "tcp", hostport)
99103
}
100104

101-
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) {
105+
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string, mptcp bool) (*net.Dialer, error) {
102106
if dst.IP.IsLinkLocalUnicast() {
103107
if sintf != "" {
104108
dst.Zone = sintf
@@ -112,6 +116,9 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error)
112116
KeepAlive: -1,
113117
Control: l.tcpContext,
114118
}
119+
if mptcp {
120+
setMPTCPForDialer(dialer)
121+
}
115122
if sintf != "" {
116123
dialer.Control = l.getControl(sintf)
117124
ief, err := net.InterfaceByName(sintf)

src/core/link_tcp_mptcp.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package core
2+
3+
import (
4+
"crypto/tls"
5+
"net"
6+
)
7+
8+
func setMPTCPForDialer(d *net.Dialer) {
9+
d.SetMultipathTCP(true)
10+
}
11+
12+
func setMPTCPForListener(lc *net.ListenConfig) {
13+
lc.SetMultipathTCP(true)
14+
}
15+
16+
func isMPTCP(c net.Conn) bool {
17+
switch tc := c.(type) {
18+
case *net.TCPConn:
19+
mp, _ := tc.MultipathTCP()
20+
return mp
21+
case *tls.Conn:
22+
if tc, ok := tc.NetConn().(*net.TCPConn); ok {
23+
mp, _ := tc.MultipathTCP()
24+
return mp
25+
}
26+
return false
27+
default:
28+
return false
29+
}
30+
}

src/core/link_tls.go

+4-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package core
33
import (
44
"context"
55
"crypto/tls"
6-
"fmt"
76
"net"
87
"net/url"
98

@@ -34,7 +33,7 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
3433
}
3534

3635
func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
37-
dialers, err := l.tcp.dialersFor(url, info)
36+
dialers, err := l.tcp.dialersFor(url, info, options)
3837
if err != nil {
3938
return nil, err
4039
}
@@ -58,17 +57,10 @@ func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options
5857
return nil, err
5958
}
6059

61-
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {
62-
hostport := url.Host
63-
if sintf != "" {
64-
if host, port, err := net.SplitHostPort(hostport); err == nil {
65-
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
66-
}
67-
}
68-
listener, err := l.listener.Listen(ctx, "tcp", hostport)
60+
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
61+
listener, err := l.tcp.listen(ctx, url, sintf, options)
6962
if err != nil {
7063
return nil, err
7164
}
72-
tlslistener := tls.NewListener(listener, l.config)
73-
return tlslistener, nil
65+
return tls.NewListener(listener, l.config), nil
7466
}

src/core/link_unix.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@ func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, option
4040
return l.dialer.DialContext(ctx, "unix", addr.String())
4141
}
4242

43-
func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
43+
func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
4444
return l.listener.Listen(ctx, "unix", url.Path)
4545
}

0 commit comments

Comments
 (0)