Skip to content

Commit 404cab9

Browse files
authored
[client] Redirect dns forwarder port 5353 to new listening port 22054 (#4707)
- Port dnat changes from #4015 (nftables/iptables/userspace) - For userspace: rewrite the original port to the target port - Remember original destination port in conntrack - Rewrite the source port back to the original port for replies - Redirect incoming port 5353 to 22054 (tcp/udp) - Revert port changes based on the network map received from management - Adjust tracer to show NAT stages
1 parent 4545ab9 commit 404cab9

File tree

25 files changed

+1125
-196
lines changed

25 files changed

+1125
-196
lines changed

client/firewall/iptables/manager_linux.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
260260
return m.router.UpdateSet(set, prefixes)
261261
}
262262

263+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
264+
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
265+
m.mutex.Lock()
266+
defer m.mutex.Unlock()
267+
268+
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
269+
}
270+
271+
// RemoveInboundDNAT removes an inbound DNAT rule.
272+
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
273+
m.mutex.Lock()
274+
defer m.mutex.Unlock()
275+
276+
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
277+
}
278+
263279
func getConntrackEstablished() []string {
264280
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
265281
}

client/firewall/iptables/router_linux.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
880880
return nberrors.FormatErrorOrNil(merr)
881881
}
882882

883+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
884+
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
885+
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
886+
887+
if _, exists := r.rules[ruleID]; exists {
888+
return nil
889+
}
890+
891+
dnatRule := []string{
892+
"-i", r.wgIface.Name(),
893+
"-p", strings.ToLower(string(protocol)),
894+
"--dport", strconv.Itoa(int(sourcePort)),
895+
"-d", localAddr.String(),
896+
"-m", "addrtype", "--dst-type", "LOCAL",
897+
"-j", "DNAT",
898+
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
899+
}
900+
901+
ruleInfo := ruleInfo{
902+
table: tableNat,
903+
chain: chainRTRDR,
904+
rule: dnatRule,
905+
}
906+
907+
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
908+
return fmt.Errorf("add inbound DNAT rule: %w", err)
909+
}
910+
r.rules[ruleID] = ruleInfo.rule
911+
912+
r.updateState()
913+
return nil
914+
}
915+
916+
// RemoveInboundDNAT removes an inbound DNAT rule.
917+
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
918+
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
919+
920+
if dnatRule, exists := r.rules[ruleID]; exists {
921+
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
922+
return fmt.Errorf("delete inbound DNAT rule: %w", err)
923+
}
924+
delete(r.rules, ruleID)
925+
}
926+
927+
r.updateState()
928+
return nil
929+
}
930+
883931
func applyPort(flag string, port *firewall.Port) []string {
884932
if port == nil {
885933
return nil

client/firewall/manager/firewall.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,20 @@ type Manager interface {
151151

152152
DisableRouting() error
153153

154-
// AddDNATRule adds a DNAT rule
154+
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
155155
AddDNATRule(ForwardRule) (Rule, error)
156156

157-
// DeleteDNATRule deletes a DNAT rule
157+
// DeleteDNATRule deletes the outbound DNAT rule.
158158
DeleteDNATRule(Rule) error
159159

160160
// UpdateSet updates the set with the given prefixes
161161
UpdateSet(hash Set, prefixes []netip.Prefix) error
162+
163+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
164+
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
165+
166+
// RemoveInboundDNAT removes inbound DNAT rule
167+
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
162168
}
163169

164170
func GenKey(format string, pair RouterPair) string {

client/firewall/nftables/manager_linux.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
376376
return m.router.UpdateSet(set, prefixes)
377377
}
378378

379+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
380+
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
381+
m.mutex.Lock()
382+
defer m.mutex.Unlock()
383+
384+
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
385+
}
386+
387+
// RemoveInboundDNAT removes an inbound DNAT rule.
388+
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
389+
m.mutex.Lock()
390+
defer m.mutex.Unlock()
391+
392+
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
393+
}
394+
379395
func (m *Manager) createWorkTable() (*nftables.Table, error) {
380396
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
381397
if err != nil {

client/firewall/nftables/router_linux.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
13501350
return nil
13511351
}
13521352

1353+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
1354+
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
1355+
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
1356+
1357+
if _, exists := r.rules[ruleID]; exists {
1358+
return nil
1359+
}
1360+
1361+
protoNum, err := protoToInt(protocol)
1362+
if err != nil {
1363+
return fmt.Errorf("convert protocol to number: %w", err)
1364+
}
1365+
1366+
exprs := []expr.Any{
1367+
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
1368+
&expr.Cmp{
1369+
Op: expr.CmpOpEq,
1370+
Register: 1,
1371+
Data: ifname(r.wgIface.Name()),
1372+
},
1373+
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
1374+
&expr.Cmp{
1375+
Op: expr.CmpOpEq,
1376+
Register: 2,
1377+
Data: []byte{protoNum},
1378+
},
1379+
&expr.Payload{
1380+
DestRegister: 3,
1381+
Base: expr.PayloadBaseTransportHeader,
1382+
Offset: 2,
1383+
Len: 2,
1384+
},
1385+
&expr.Cmp{
1386+
Op: expr.CmpOpEq,
1387+
Register: 3,
1388+
Data: binaryutil.BigEndian.PutUint16(sourcePort),
1389+
},
1390+
}
1391+
1392+
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
1393+
1394+
exprs = append(exprs,
1395+
&expr.Immediate{
1396+
Register: 1,
1397+
Data: localAddr.AsSlice(),
1398+
},
1399+
&expr.Immediate{
1400+
Register: 2,
1401+
Data: binaryutil.BigEndian.PutUint16(targetPort),
1402+
},
1403+
&expr.NAT{
1404+
Type: expr.NATTypeDestNAT,
1405+
Family: uint32(nftables.TableFamilyIPv4),
1406+
RegAddrMin: 1,
1407+
RegProtoMin: 2,
1408+
RegProtoMax: 0,
1409+
},
1410+
)
1411+
1412+
dnatRule := &nftables.Rule{
1413+
Table: r.workTable,
1414+
Chain: r.chains[chainNameRoutingRdr],
1415+
Exprs: exprs,
1416+
UserData: []byte(ruleID),
1417+
}
1418+
r.conn.AddRule(dnatRule)
1419+
1420+
if err := r.conn.Flush(); err != nil {
1421+
return fmt.Errorf("add inbound DNAT rule: %w", err)
1422+
}
1423+
1424+
r.rules[ruleID] = dnatRule
1425+
1426+
return nil
1427+
}
1428+
1429+
// RemoveInboundDNAT removes an inbound DNAT rule.
1430+
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
1431+
if err := r.refreshRulesMap(); err != nil {
1432+
return fmt.Errorf(refreshRulesMapError, err)
1433+
}
1434+
1435+
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
1436+
1437+
if rule, exists := r.rules[ruleID]; exists {
1438+
if err := r.conn.DelRule(rule); err != nil {
1439+
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
1440+
}
1441+
if err := r.conn.Flush(); err != nil {
1442+
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
1443+
}
1444+
delete(r.rules, ruleID)
1445+
}
1446+
1447+
return nil
1448+
}
1449+
13531450
// applyNetwork generates nftables expressions for networks (CIDR) or sets
13541451
func (r *router) applyNetwork(
13551452
network firewall.Network,

client/firewall/uspfilter/conntrack/common.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ type BaseConnTrack struct {
2222
PacketsRx atomic.Uint64
2323
BytesTx atomic.Uint64
2424
BytesRx atomic.Uint64
25+
26+
DNATOrigPort atomic.Uint32
2527
}
2628

2729
// these small methods will be inlined by the compiler

client/firewall/uspfilter/conntrack/tcp.go

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
157157
return tracker
158158
}
159159

160-
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
160+
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
161161
key := ConnKey{
162162
SrcIP: srcIP,
163163
DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
171171

172172
if exists {
173173
t.updateState(key, conn, flags, direction, size)
174-
return key, true
174+
return key, uint16(conn.DNATOrigPort.Load()), true
175175
}
176176

177-
return key, false
177+
return key, 0, false
178178
}
179179

180-
// TrackOutbound records an outbound TCP connection
181-
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
182-
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
183-
// if (inverted direction) conn is not tracked, track this direction
184-
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
180+
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
181+
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
182+
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
183+
return origPort
185184
}
185+
// if (inverted direction) conn is not tracked, track this direction
186+
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
187+
return 0
186188
}
187189

188190
// TrackInbound processes an inbound TCP packet and updates connection state
189-
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
190-
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
191+
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
192+
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
191193
}
192194

193195
// track is the common implementation for tracking both inbound and outbound connections
194-
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
195-
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
196+
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
197+
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
196198
if exists || flags&TCPSyn == 0 {
197199
return
198200
}
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
210212

211213
conn.tombstone.Store(false)
212214
conn.state.Store(int32(TCPStateNew))
215+
conn.DNATOrigPort.Store(uint32(origPort))
213216

214-
t.logger.Trace2("New %s TCP connection: %s", direction, key)
217+
if origPort != 0 {
218+
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
219+
} else {
220+
t.logger.Trace2("New %s TCP connection: %s", direction, key)
221+
}
215222
t.updateState(key, conn, flags, direction, size)
216223

217224
t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
449456
}
450457
}
451458

459+
// GetConnection safely retrieves a connection state
460+
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
461+
t.mutex.RLock()
462+
defer t.mutex.RUnlock()
463+
464+
key := ConnKey{
465+
SrcIP: srcIP,
466+
DstIP: dstIP,
467+
SrcPort: srcPort,
468+
DstPort: dstPort,
469+
}
470+
conn, exists := t.connections[key]
471+
return conn, exists
472+
}
473+
452474
// Close stops the cleanup routine and releases resources
453475
func (t *TCPTracker) Close() {
454476
t.tickerCancel()

client/firewall/uspfilter/conntrack/tcp_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
603603
serverPort := uint16(80)
604604

605605
// 1. Client sends SYN (we receive it as inbound)
606-
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
606+
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
607607

608608
key := ConnKey{
609609
SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
623623
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
624624

625625
// 3. Client sends ACK to complete handshake
626-
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
626+
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
627627
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
628628

629629
// 4. Test data transfer
630630
// Client sends data
631-
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
631+
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
632632

633633
// Server sends ACK for data
634634
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
637637
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
638638

639639
// Client sends ACK for data
640-
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
640+
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
641641

642642
// Verify state and counters
643643
require.Equal(t, TCPStateEstablished, conn.GetState())

0 commit comments

Comments
 (0)