Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions client/firewall/uspfilter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ const (

var errNatNotSupported = errors.New("nat not supported with userspace firewall")

// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}

// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule

Expand Down Expand Up @@ -113,6 +119,9 @@ type Manager struct {
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex

netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
}

// decoder for packages
Expand Down Expand Up @@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
}
m.routingEnabled.Store(false)

Expand Down Expand Up @@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true
}

// If requested we pass local traffic to internal interfaces to the forwarder.
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
if m.shouldForward(d, dstIP) {
return m.handleForwardedLocalTraffic(packetData)
}

Expand Down Expand Up @@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error {

return nil
}

// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
m.netstackServices[key] = struct{}{}
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
}

// UnregisterNetstackService removes a service from the netstack registry
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
delete(m.netstackServices, key)
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
}

// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
switch protocol {
case nftypes.TCP:
return layers.LayerTypeTCP
case nftypes.UDP:
return layers.LayerTypeUDP
case nftypes.ICMP:
return layers.LayerTypeICMPv4
default:
return gopacket.LayerType(0) // Invalid/unknown
}
}

// shouldForward determines if a packet should be forwarded to the forwarder.
// The forwarder handles routing packets to the native OS network stack.
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
// not enabled, never forward
if !m.localForwarding {
return false
}

// netstack always needs to forward because it's lacking a native interface
// exception for registered netstack services, those should go to netstack listeners
if m.netstack {
return !m.hasMatchingNetstackService(d)
}

// traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP {
return true
}

// traffic to our NetBird IP, not netstack mode - send to netstack listeners
return false
}

// hasMatchingNetstackService checks if there's a registered netstack service for this packet
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
if len(d.decoded) < 2 {
return false
}

var dstPort uint16
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
dstPort = uint16(d.udp.DstPort)
default:
return false
}

key := serviceKey{protocol: d.decoded[1], port: dstPort}
m.netstackServiceMutex.RLock()
_, exists := m.netstackServices[key]
m.netstackServiceMutex.RUnlock()

return exists
}
1 change: 0 additions & 1 deletion client/internal/dnsfwd/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
}
}

66 changes: 50 additions & 16 deletions client/internal/dnsfwd/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"

nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
Expand All @@ -33,7 +34,7 @@ type firewaller interface {
}

type DNSForwarder struct {
listenAddress string
listenAddress netip.AddrPort
ttl uint32
statusRecorder *peer.Status

Expand All @@ -47,9 +48,11 @@ type DNSForwarder struct {
firewall firewaller
resolver resolver
cache *cache

wgIface wgIface
}

func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
Expand All @@ -58,49 +61,80 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
cache: newCache(),
wgIface: wgIface,
}
}

func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
var netstackNet *netstack.Net
if f.wgIface != nil {
netstackNet = f.wgIface.GetNet()
}

addrDesc := f.listenAddress.String()
if netstackNet != nil {
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
}
log.Infof("starting DNS forwarder on address=%s", addrDesc)

udpLn, err := f.createUDPListener(netstackNet)
if err != nil {
return fmt.Errorf("create UDP listener: %w", err)
}

tcpLn, err := f.createTCPListener(netstackNet)
if err != nil {
return fmt.Errorf("create TCP listener: %w", err)
}

// UDP server
mux := dns.NewServeMux()
f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
PacketConn: udpLn,
Handler: mux,
}

// TCP server
tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{
Addr: f.listenAddress,
Net: "tcp",
Handler: tcpMux,
Listener: tcpLn,
Handler: tcpMux,
}

f.UpdateDomains(entries)

errCh := make(chan error, 2)

go func() {
log.Infof("DNS UDP listener running on %s", f.listenAddress)
errCh <- f.dnsServer.ListenAndServe()
log.Infof("DNS UDP listener running on %s", addrDesc)
errCh <- f.dnsServer.ActivateAndServe()
}()
go func() {
log.Infof("DNS TCP listener running on %s", f.listenAddress)
errCh <- f.tcpServer.ListenAndServe()
log.Infof("DNS TCP listener running on %s", addrDesc)
errCh <- f.tcpServer.ActivateAndServe()
}()

// return the first error we get (e.g. bind failure or shutdown)
return <-errCh
}

func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
if netstackNet != nil {
return netstackNet.ListenUDPAddrPort(f.listenAddress)
}

return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
}

func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
if netstackNet != nil {
return netstackNet.ListenTCPAddrPort(f.listenAddress)
}

return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
}

func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
Expand Down
20 changes: 10 additions & 10 deletions client/internal/dnsfwd/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}

forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver

d, err := domain.FromString(tt.configuredDomain)
Expand Down Expand Up @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
mockResolver := &MockResolver{}

// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver

// Create entries and track sets
Expand Down Expand Up @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}

forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver

// Configure a single domain
Expand Down Expand Up @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)

d, err := domain.FromString(tt.configured)
require.NoError(t, err)
Expand Down Expand Up @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver

d, _ := domain.FromString("example.com")
Expand Down Expand Up @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
// a subsequent upstream failure still returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver

d, err := domain.FromString("example.com")
Expand Down Expand Up @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Verifies that cache normalization works across casing and trailing dot variations.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver

d, err := domain.FromString("ExAmPlE.CoM")
Expand Down Expand Up @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}

forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver

// Set up complex overlapping patterns
Expand Down Expand Up @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}

forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver

d, err := domain.FromString("example.com")
Expand Down Expand Up @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {

func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)

query := &dns.Msg{}
// Don't set any question
Expand Down
Loading
Loading