Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/internal/routemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ type DefaultManager struct {
func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
sysOps := systemops.New(config.WGInterface, notifier)

if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
Expand Down
8 changes: 8 additions & 0 deletions client/internal/routemanager/systemops/flush_nonbsd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)

package systemops

// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}
8 changes: 4 additions & 4 deletions client/internal/routemanager/systemops/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
}

func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s))
sysOps := New(nil, nil)
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysOps.refCounter.LoadData((*ExclusionCounter)(s))

return sysops.refCounter.Flush()
return sysOps.refCounter.Flush()
}

func (s *ShutdownState) MarshalJSON() ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion client/internal/routemanager/systemops/systemops.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time
}

func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
Expand Down
4 changes: 2 additions & 2 deletions client/internal/routemanager/systemops/systemops_bsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}

r := NewSysOps(nil, nil)
r := New(nil, nil)

var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
Expand Down Expand Up @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin

nexthop := Nexthop{netip.Addr{}, netIntf}

r := NewSysOps(nil, nil)
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {

wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)

r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {

wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)

r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
Expand Down Expand Up @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})

r := NewSysOps(wgInterface, nil)
r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")
Expand Down
78 changes: 77 additions & 1 deletion client/internal/routemanager/systemops/systemops_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,39 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
"syscall"
"time"
"unsafe"

"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)

var routeProtoFlag int

func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}

func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
Expand All @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager)
}

// FlushMarkedRoutes removes all routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}

msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}

var merr *multierror.Error
flushedCount := 0

for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}

if rtMsg.Flags&routeProtoFlag == 0 {
continue
}

routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}

if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}

nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}

if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}

flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}

if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}

return nberrors.FormatErrorOrNil(merr)
}

func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
Expand Down Expand Up @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil
}
return nil, fmt.Errorf("read state file: %w", err)
Expand Down
9 changes: 9 additions & 0 deletions client/server/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)

Expand Down Expand Up @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}

// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}

return nberrors.FormatErrorOrNil(merr)
}
Loading