Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/internal/routemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ type DefaultManager struct {
func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
sysOps := systemops.New(config.WGInterface, notifier)

if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
Expand Down
8 changes: 8 additions & 0 deletions client/internal/routemanager/systemops/flush_nonbsd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)

package systemops

// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}
8 changes: 4 additions & 4 deletions client/internal/routemanager/systemops/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
}

func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s))
sysOps := New(nil, nil)
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysOps.refCounter.LoadData((*ExclusionCounter)(s))

return sysops.refCounter.Flush()
return sysOps.refCounter.Flush()
}

func (s *ShutdownState) MarshalJSON() ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion client/internal/routemanager/systemops/systemops.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time
}

func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
Expand Down
4 changes: 2 additions & 2 deletions client/internal/routemanager/systemops/systemops_bsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}

r := NewSysOps(nil, nil)
r := New(nil, nil)

var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
Expand Down Expand Up @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin

nexthop := Nexthop{netip.Addr{}, netIntf}

r := NewSysOps(nil, nil)
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {

wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)

r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {

wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)

r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
Expand Down Expand Up @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})

r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")
Expand Down
78 changes: 77 additions & 1 deletion client/internal/routemanager/systemops/systemops_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,39 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
"syscall"
"time"
"unsafe"

"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)

var routeProtoFlag int

func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}

func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
Expand All @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager)
}

// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}

msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}

var merr *multierror.Error
flushedCount := 0

for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}

if rtMsg.Flags&routeProtoFlag == 0 {
continue
}

routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}

if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}

nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}

if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}

flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}

if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}

return nberrors.FormatErrorOrNil(merr)
}

func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
Expand Down Expand Up @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil
}
return nil, fmt.Errorf("read state file: %w", err)
Expand Down
9 changes: 9 additions & 0 deletions client/server/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)

Expand Down Expand Up @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}

// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}

return nberrors.FormatErrorOrNil(merr)
}
Loading