Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions client/internal/dns/host_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"

log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

"github.com/netbirdio/netbird/client/internal/statemanager"
)
Expand Down Expand Up @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
}

func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error

if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}

var (
searchDomains []string
matchDomains []string
)

err = s.recordSystemDNSSettings(true)
if err != nil {
if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}

if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
if err := s.addLocalDNS(); err != nil {
log.Warnf("failed to add local DNS: %v", err)
}
s.updateState(stateManager)
}

for _, dConf := range config.Domains {
Expand All @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
}

matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
Expand All @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil {
return fmt.Errorf("add match domains: %w", err)
}
s.updateState(stateManager)

searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
Expand All @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.updateState(stateManager)

if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
Expand All @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil
}

func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}

func (s *systemConfigurator) string() string {
return "scutil"
}
Expand Down Expand Up @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration")
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
}
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
log.Info("Not enabling local DNS server")
return nil
}

if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
}

return nil
Expand Down
111 changes: 111 additions & 0 deletions client/internal/dns/host_darwin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//go:build !ios

package dns

import (
"context"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/netbirdio/netbird/client/internal/statemanager"
)

func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}

tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")

sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
defer func() {
require.NoError(t, sm.Stop(context.Background()))
}()

configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}

config := HostDNSConfig{
ServerIP: netip.MustParseAddr("100.64.0.1"),
ServerPort: 53,
RouteAll: true,
Domains: []DomainConfig{
{Domain: "example.com", MatchOnly: true},
},
}

err := configurator.applyDNSConfig(config, sm)
require.NoError(t, err)

require.NoError(t, sm.PersistState(context.Background()))

searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)

defer func() {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}()

for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
t.Logf("Key %s exists before cleanup", key)
}
}

sm2 := statemanager.New(stateFile)
sm2.RegisterState(&ShutdownState{})
err = sm2.LoadState(&ShutdownState{})
require.NoError(t, err)

state := sm2.GetState(&ShutdownState{})
if state == nil {
t.Skip("State not saved, skipping cleanup test")
}

shutdownState, ok := state.(*ShutdownState)
require.True(t, ok)

err = shutdownState.Cleanup()
require.NoError(t, err)

for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}

func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
output, err := cmd.CombinedOutput()
if err != nil {
if strings.Contains(string(output), "No such key") {
return false, nil
}
return false, err
}
return !strings.Contains(string(output), "No such key"), nil
}

func removeTestDNSKey(key string) error {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
_, err := cmd.CombinedOutput()
return err
}
26 changes: 12 additions & 14 deletions client/internal/dns/host_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}

if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
r.updateState(stateManager)

var searchDomains, matchDomains []string
for _, dConf := range config.Domains {
Expand All @@ -212,13 +206,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
r.nrptEntryCount = 0
}

if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
r.updateState(stateManager)

if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err)
Expand All @@ -229,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil
}

func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}

func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err)
Expand Down
5 changes: 5 additions & 0 deletions client/internal/dns/unclean_shutdown_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
)

type ShutdownState struct {
CreatedKeys []string
}

func (s *ShutdownState) Name() string {
Expand All @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create host manager: %w", err)
}

for _, key := range s.CreatedKeys {
manager.createdKeys[key] = struct{}{}
}

if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}
Expand Down
Loading