Skip to content

Commit 14d5243

Browse files
authored
proxyprotocol: unify the ip family before sending Proxy Protocol (#1055)
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
1 parent f92e628 commit 14d5243

File tree

3 files changed

+127
-17
lines changed

3 files changed

+127
-17
lines changed

pkg/proxy/net/proxy_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ func TestProxyReadWrite(t *testing.T) {
5656
n, err := prw.Read(data)
5757
require.NoError(t, err)
5858
require.Equal(t, len(message), n)
59-
require.Equal(t, p.SrcAddress, prw.Proxy().SrcAddress)
59+
60+
parsedAddr, ok := prw.Proxy().SrcAddress.(*net.TCPAddr)
61+
require.True(t, ok)
62+
require.Equal(t, addr.IP.To4(), parsedAddr.IP)
63+
require.Equal(t, addr.Port, parsedAddr.Port)
6064
require.Equal(t, addr.String(), prw.RemoteAddr().String())
6165
require.Equal(t, proxyAddr, prw.ProxyAddr().String())
6266
}, 1)

pkg/proxy/proxyprotocol/proxy.go

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,31 +36,35 @@ func (p *Proxy) ToBytes() ([]byte, error) {
3636

3737
switch sadd := srcAddr.(type) {
3838
case *net.TCPAddr:
39-
addressFamily = ProxyAFINet
40-
if len(sadd.IP) == net.IPv6len {
41-
addressFamily = ProxyAFINet6
42-
}
43-
network = ProxyNetworkStream
4439
dadd, ok := dstAddr.(*net.TCPAddr)
4540
if !ok {
4641
return nil, ErrAddressFamilyMismatch
4742
}
48-
buf = append(buf, sadd.IP...)
49-
buf = append(buf, dadd.IP...)
50-
buf = append(buf, byte(sadd.Port>>8), byte(sadd.Port))
51-
buf = append(buf, byte(dadd.Port>>8), byte(dadd.Port))
52-
case *net.UDPAddr:
43+
saddUnifiedIP, daddUnifiedIP := unifyIPFamily(sadd.IP, dadd.IP)
44+
5345
addressFamily = ProxyAFINet
54-
if len(sadd.IP) == net.IPv6len {
46+
if len(saddUnifiedIP) == net.IPv6len {
5547
addressFamily = ProxyAFINet6
5648
}
57-
network = ProxyNetworkDgram
49+
network = ProxyNetworkStream
50+
buf = append(buf, saddUnifiedIP...)
51+
buf = append(buf, daddUnifiedIP...)
52+
buf = append(buf, byte(sadd.Port>>8), byte(sadd.Port))
53+
buf = append(buf, byte(dadd.Port>>8), byte(dadd.Port))
54+
case *net.UDPAddr:
5855
dadd, ok := dstAddr.(*net.UDPAddr)
5956
if !ok {
6057
return nil, ErrAddressFamilyMismatch
6158
}
62-
buf = append(buf, sadd.IP...)
63-
buf = append(buf, dadd.IP...)
59+
saddUnifiedIP, daddUnifiedIP := unifyIPFamily(sadd.IP, dadd.IP)
60+
61+
addressFamily = ProxyAFINet
62+
if len(saddUnifiedIP) == net.IPv6len {
63+
addressFamily = ProxyAFINet6
64+
}
65+
network = ProxyNetworkDgram
66+
buf = append(buf, saddUnifiedIP...)
67+
buf = append(buf, daddUnifiedIP...)
6468
buf = append(buf, byte(sadd.Port>>8), byte(sadd.Port))
6569
buf = append(buf, byte(dadd.Port>>8), byte(dadd.Port))
6670
case *net.UnixAddr:
@@ -94,6 +98,19 @@ func (p *Proxy) ToBytes() ([]byte, error) {
9498
return buf, nil
9599
}
96100

101+
// unifyIPFamily unifies the IP family of ip1 and ip2.
102+
// If both of them are IPv4 (or IPv4 mapped IPv6), return the IPv4 addresses.
103+
// Else, convert both of them to IPv6 and return.
104+
func unifyIPFamily(ip1 net.IP, ip2 net.IP) (net.IP, net.IP) {
105+
ip1To4 := ip1.To4()
106+
ip2To4 := ip2.To4()
107+
if ip1To4 != nil && ip2To4 != nil {
108+
return ip1To4, ip2To4
109+
}
110+
111+
return ip1.To16(), ip2.To16()
112+
}
113+
97114
func ParseProxyV2(rd io.Reader) (m *Proxy, n int, err error) {
98115
var hdr [4]byte
99116

pkg/proxy/proxyprotocol/proxy_test.go

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,15 @@ func TestProxyParse(t *testing.T) {
5050
p, _, err := ParseProxyV2(srv)
5151
require.NoError(t, err)
5252
require.NotNil(t, p)
53-
require.Equal(t, tcpaddr, p.SrcAddress)
54-
require.Equal(t, tcpaddr, p.DstAddress)
53+
54+
srcAddr, ok := p.SrcAddress.(*net.TCPAddr)
55+
require.True(t, ok)
56+
require.Equal(t, tcpaddr.IP.To4(), srcAddr.IP)
57+
require.Equal(t, tcpaddr.Port, srcAddr.Port)
58+
dstAddr, ok := p.DstAddress.(*net.TCPAddr)
59+
require.True(t, ok)
60+
require.Equal(t, tcpaddr.IP.To4(), dstAddr.IP)
61+
require.Equal(t, tcpaddr.Port, dstAddr.Port)
5562
require.Equal(t, ProxyVersion2, p.Version)
5663
require.Equal(t, ProxyCommandLocal, p.Command)
5764
require.Len(t, p.TLV, 2)
@@ -92,3 +99,85 @@ func TestProxyToBytes(t *testing.T) {
9299
_, err = hdr.ToBytes()
93100
require.NoError(t, err)
94101
}
102+
103+
func TestMixIPv4AndIPv6ProxyToBytes(t *testing.T) {
104+
tests := []struct {
105+
srcIP net.IP
106+
dstIP net.IP
107+
srcPort int
108+
dstPort int
109+
wantAF ProxyAddressFamily
110+
wantIPLen int
111+
}{
112+
{
113+
srcIP: net.ParseIP("192.168.1.1"),
114+
dstIP: net.ParseIP("192.168.1.2"),
115+
srcPort: 1234,
116+
dstPort: 5678,
117+
wantAF: ProxyAFINet,
118+
wantIPLen: net.IPv4len,
119+
},
120+
{
121+
srcIP: net.ParseIP("2001:db8::1"),
122+
dstIP: net.ParseIP("2001:db8::2"),
123+
srcPort: 1234,
124+
dstPort: 5678,
125+
wantAF: ProxyAFINet6,
126+
wantIPLen: net.IPv6len,
127+
},
128+
{
129+
srcIP: net.ParseIP("192.168.1.1"),
130+
dstIP: net.ParseIP("2001:db8::1"),
131+
srcPort: 1234,
132+
dstPort: 5678,
133+
wantAF: ProxyAFINet6,
134+
wantIPLen: net.IPv6len,
135+
},
136+
{
137+
srcIP: net.ParseIP("192.168.1.1"),
138+
dstIP: net.ParseIP("::ffff:192.168.1.2"),
139+
srcPort: 1234,
140+
dstPort: 5678,
141+
wantAF: ProxyAFINet,
142+
wantIPLen: net.IPv4len,
143+
},
144+
{
145+
srcIP: net.ParseIP("::ffff:192.168.1.1"),
146+
dstIP: net.ParseIP("::ffff:192.168.1.2"),
147+
srcPort: 1234,
148+
dstPort: 5678,
149+
wantAF: ProxyAFINet,
150+
wantIPLen: net.IPv4len,
151+
},
152+
{
153+
srcIP: net.ParseIP("::ffff:192.168.1.1"),
154+
dstIP: net.ParseIP("2001:db8::1"),
155+
srcPort: 1234,
156+
dstPort: 5678,
157+
wantAF: ProxyAFINet6,
158+
wantIPLen: net.IPv6len,
159+
},
160+
}
161+
162+
for _, tt := range tests {
163+
hdr := &Proxy{
164+
Version: ProxyVersion2,
165+
Command: ProxyCommandProxy,
166+
SrcAddress: &net.TCPAddr{IP: tt.srcIP, Port: tt.srcPort},
167+
DstAddress: &net.TCPAddr{IP: tt.dstIP, Port: tt.dstPort},
168+
}
169+
170+
hdrBytes, err := hdr.ToBytes()
171+
require.NoError(t, err)
172+
require.GreaterOrEqual(t, len(hdrBytes), len(MagicV2)+4)
173+
174+
addressFamily := ProxyAddressFamily(hdrBytes[len(MagicV2)+1] >> 4)
175+
require.Equal(t, tt.wantAF, addressFamily)
176+
177+
length := int(hdrBytes[len(MagicV2)+2])<<8 | int(hdrBytes[len(MagicV2)+3])
178+
require.Equal(t, len(hdrBytes)-4-len(MagicV2), length)
179+
180+
expectedPayloadSize := tt.wantIPLen*2 + 4
181+
require.Equal(t, expectedPayloadSize, length)
182+
}
183+
}

0 commit comments

Comments
 (0)