Skip to content

Commit 86eff0d

Browse files
authored
[client] Fix netstack dns forwarder (#4727)
1 parent 43c9a51 commit 86eff0d

File tree

6 files changed

+231
-60
lines changed

6 files changed

+231
-60
lines changed

client/firewall/uspfilter/filter.go

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ const (
5050

5151
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
5252

53+
// serviceKey represents a protocol/port combination for netstack service registry
54+
type serviceKey struct {
55+
protocol gopacket.LayerType
56+
port uint16
57+
}
58+
5359
// RuleSet is a set of rules grouped by a string key
5460
type RuleSet map[string]PeerRule
5561

@@ -113,6 +119,9 @@ type Manager struct {
113119
portDNATEnabled atomic.Bool
114120
portDNATRules []portDNATRule
115121
portDNATMutex sync.RWMutex
122+
123+
netstackServices map[serviceKey]struct{}
124+
netstackServiceMutex sync.RWMutex
116125
}
117126

118127
// decoder for packages
@@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
203212
localForwarding: enableLocalForwarding,
204213
dnatMappings: make(map[netip.Addr]netip.Addr),
205214
portDNATRules: []portDNATRule{},
215+
netstackServices: make(map[serviceKey]struct{}),
206216
}
207217
m.routingEnabled.Store(false)
208218

@@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
838848
return true
839849
}
840850

841-
// If requested we pass local traffic to internal interfaces to the forwarder.
842-
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
843-
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
851+
if m.shouldForward(d, dstIP) {
844852
return m.handleForwardedLocalTraffic(packetData)
845853
}
846854

@@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error {
12741282

12751283
return nil
12761284
}
1285+
1286+
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
1287+
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
1288+
m.netstackServiceMutex.Lock()
1289+
defer m.netstackServiceMutex.Unlock()
1290+
layerType := m.protocolToLayerType(protocol)
1291+
key := serviceKey{protocol: layerType, port: port}
1292+
m.netstackServices[key] = struct{}{}
1293+
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
1294+
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
1295+
}
1296+
1297+
// UnregisterNetstackService removes a service from the netstack registry
1298+
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
1299+
m.netstackServiceMutex.Lock()
1300+
defer m.netstackServiceMutex.Unlock()
1301+
layerType := m.protocolToLayerType(protocol)
1302+
key := serviceKey{protocol: layerType, port: port}
1303+
delete(m.netstackServices, key)
1304+
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
1305+
}
1306+
1307+
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
1308+
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
1309+
switch protocol {
1310+
case nftypes.TCP:
1311+
return layers.LayerTypeTCP
1312+
case nftypes.UDP:
1313+
return layers.LayerTypeUDP
1314+
case nftypes.ICMP:
1315+
return layers.LayerTypeICMPv4
1316+
default:
1317+
return gopacket.LayerType(0) // Invalid/unknown
1318+
}
1319+
}
1320+
1321+
// shouldForward determines if a packet should be forwarded to the forwarder.
1322+
// The forwarder handles routing packets to the native OS network stack.
1323+
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
1324+
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
1325+
// not enabled, never forward
1326+
if !m.localForwarding {
1327+
return false
1328+
}
1329+
1330+
// netstack always needs to forward because it's lacking a native interface
1331+
// exception for registered netstack services, those should go to netstack listeners
1332+
if m.netstack {
1333+
return !m.hasMatchingNetstackService(d)
1334+
}
1335+
1336+
// traffic to our other local interfaces (not NetBird IP) - always forward
1337+
if dstIP != m.wgIface.Address().IP {
1338+
return true
1339+
}
1340+
1341+
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
1342+
return false
1343+
}
1344+
1345+
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
1346+
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
1347+
if len(d.decoded) < 2 {
1348+
return false
1349+
}
1350+
1351+
var dstPort uint16
1352+
switch d.decoded[1] {
1353+
case layers.LayerTypeTCP:
1354+
dstPort = uint16(d.tcp.DstPort)
1355+
case layers.LayerTypeUDP:
1356+
dstPort = uint16(d.udp.DstPort)
1357+
default:
1358+
return false
1359+
}
1360+
1361+
key := serviceKey{protocol: d.decoded[1], port: dstPort}
1362+
m.netstackServiceMutex.RLock()
1363+
_, exists := m.netstackServices[key]
1364+
m.netstackServiceMutex.RUnlock()
1365+
1366+
return exists
1367+
}

client/internal/dnsfwd/cache_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
8383
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
8484
}
8585
}
86-

client/internal/dnsfwd/forwarder.go

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/hashicorp/go-multierror"
1515
"github.com/miekg/dns"
1616
log "github.com/sirupsen/logrus"
17+
"golang.zx2c4.com/wireguard/tun/netstack"
1718

1819
nberrors "github.com/netbirdio/netbird/client/errors"
1920
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -33,7 +34,7 @@ type firewaller interface {
3334
}
3435

3536
type DNSForwarder struct {
36-
listenAddress string
37+
listenAddress netip.AddrPort
3738
ttl uint32
3839
statusRecorder *peer.Status
3940

@@ -47,9 +48,11 @@ type DNSForwarder struct {
4748
firewall firewaller
4849
resolver resolver
4950
cache *cache
51+
52+
wgIface wgIface
5053
}
5154

52-
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
55+
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
5356
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
5457
return &DNSForwarder{
5558
listenAddress: listenAddress,
@@ -58,49 +61,80 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
5861
statusRecorder: statusRecorder,
5962
resolver: net.DefaultResolver,
6063
cache: newCache(),
64+
wgIface: wgIface,
6165
}
6266
}
6367

6468
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
65-
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
69+
var netstackNet *netstack.Net
70+
if f.wgIface != nil {
71+
netstackNet = f.wgIface.GetNet()
72+
}
73+
74+
addrDesc := f.listenAddress.String()
75+
if netstackNet != nil {
76+
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
77+
}
78+
log.Infof("starting DNS forwarder on address=%s", addrDesc)
79+
80+
udpLn, err := f.createUDPListener(netstackNet)
81+
if err != nil {
82+
return fmt.Errorf("create UDP listener: %w", err)
83+
}
84+
85+
tcpLn, err := f.createTCPListener(netstackNet)
86+
if err != nil {
87+
return fmt.Errorf("create TCP listener: %w", err)
88+
}
6689

67-
// UDP server
6890
mux := dns.NewServeMux()
6991
f.mux = mux
7092
mux.HandleFunc(".", f.handleDNSQueryUDP)
7193
f.dnsServer = &dns.Server{
72-
Addr: f.listenAddress,
73-
Net: "udp",
74-
Handler: mux,
94+
PacketConn: udpLn,
95+
Handler: mux,
7596
}
7697

77-
// TCP server
7898
tcpMux := dns.NewServeMux()
7999
f.tcpMux = tcpMux
80100
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
81101
f.tcpServer = &dns.Server{
82-
Addr: f.listenAddress,
83-
Net: "tcp",
84-
Handler: tcpMux,
102+
Listener: tcpLn,
103+
Handler: tcpMux,
85104
}
86105

87106
f.UpdateDomains(entries)
88107

89108
errCh := make(chan error, 2)
90109

91110
go func() {
92-
log.Infof("DNS UDP listener running on %s", f.listenAddress)
93-
errCh <- f.dnsServer.ListenAndServe()
111+
log.Infof("DNS UDP listener running on %s", addrDesc)
112+
errCh <- f.dnsServer.ActivateAndServe()
94113
}()
95114
go func() {
96-
log.Infof("DNS TCP listener running on %s", f.listenAddress)
97-
errCh <- f.tcpServer.ListenAndServe()
115+
log.Infof("DNS TCP listener running on %s", addrDesc)
116+
errCh <- f.tcpServer.ActivateAndServe()
98117
}()
99118

100-
// return the first error we get (e.g. bind failure or shutdown)
101119
return <-errCh
102120
}
103121

122+
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
123+
if netstackNet != nil {
124+
return netstackNet.ListenUDPAddrPort(f.listenAddress)
125+
}
126+
127+
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
128+
}
129+
130+
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
131+
if netstackNet != nil {
132+
return netstackNet.ListenTCPAddrPort(f.listenAddress)
133+
}
134+
135+
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
136+
}
137+
104138
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
105139
f.mutex.Lock()
106140
defer f.mutex.Unlock()

client/internal/dnsfwd/forwarder_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
297297
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
298298
}
299299

300-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
300+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
301301
forwarder.resolver = mockResolver
302302

303303
d, err := domain.FromString(tt.configuredDomain)
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
402402
mockResolver := &MockResolver{}
403403

404404
// Set up forwarder
405-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
405+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
406406
forwarder.resolver = mockResolver
407407

408408
// Create entries and track sets
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
489489
mockFirewall := &MockFirewall{}
490490
mockResolver := &MockResolver{}
491491

492-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
492+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
493493
forwarder.resolver = mockResolver
494494

495495
// Configure a single domain
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
584584

585585
for _, tt := range tests {
586586
t.Run(tt.name, func(t *testing.T) {
587-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
587+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
588588

589589
d, err := domain.FromString(tt.configured)
590590
require.NoError(t, err)
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
616616
func TestDNSForwarder_TCPTruncation(t *testing.T) {
617617
// Test that large UDP responses are truncated with TC bit set
618618
mockResolver := &MockResolver{}
619-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
619+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
620620
forwarder.resolver = mockResolver
621621

622622
d, _ := domain.FromString("example.com")
@@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
652652
// a subsequent upstream failure still returns a successful response from cache.
653653
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
654654
mockResolver := &MockResolver{}
655-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
655+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
656656
forwarder.resolver = mockResolver
657657

658658
d, err := domain.FromString("example.com")
@@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
696696
// Verifies that cache normalization works across casing and trailing dot variations.
697697
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
698698
mockResolver := &MockResolver{}
699-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
699+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
700700
forwarder.resolver = mockResolver
701701

702702
d, err := domain.FromString("ExAmPlE.CoM")
@@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
742742
mockFirewall := &MockFirewall{}
743743
mockResolver := &MockResolver{}
744744

745-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
745+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
746746
forwarder.resolver = mockResolver
747747

748748
// Set up complex overlapping patterns
@@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
804804
mockFirewall := &MockFirewall{}
805805
mockResolver := &MockResolver{}
806806

807-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
807+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
808808
forwarder.resolver = mockResolver
809809

810810
d, err := domain.FromString("example.com")
@@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
925925

926926
func TestDNSForwarder_EmptyQuery(t *testing.T) {
927927
// Test handling of malformed query with no questions
928-
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
928+
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
929929

930930
query := &dns.Msg{}
931931
// Don't set any question

0 commit comments

Comments
 (0)