Skip to content

Commit c92e6c1

Browse files
authored
[client] Block on all subsystems on shutdown (#4709)
1 parent 641eb51 commit c92e6c1

File tree

7 files changed

+139
-65
lines changed

7 files changed

+139
-65
lines changed

client/internal/connect.go

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/netbirdio/netbird/client/internal/peer"
2626
"github.com/netbirdio/netbird/client/internal/profilemanager"
2727
"github.com/netbirdio/netbird/client/internal/stdnet"
28+
nbnet "github.com/netbirdio/netbird/client/net"
2829
cProto "github.com/netbirdio/netbird/client/proto"
2930
"github.com/netbirdio/netbird/client/ssh"
3031
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
3435
relayClient "github.com/netbirdio/netbird/shared/relay/client"
3536
signal "github.com/netbirdio/netbird/shared/signal/client"
3637
"github.com/netbirdio/netbird/util"
37-
nbnet "github.com/netbirdio/netbird/client/net"
3838
"github.com/netbirdio/netbird/version"
3939
)
4040

@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
289289
}
290290

291291
<-engineCtx.Done()
292+
292293
c.engineMutex.Lock()
293-
if c.engine != nil && c.engine.wgInterface != nil {
294-
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
295-
if err := c.engine.Stop(); err != nil {
294+
engine := c.engine
295+
c.engine = nil
296+
c.engineMutex.Unlock()
297+
298+
if engine != nil && engine.wgInterface != nil {
299+
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
300+
if err := engine.Stop(); err != nil {
296301
log.Errorf("Failed to stop engine: %v", err)
297302
}
298-
c.engine = nil
299303
}
300-
c.engineMutex.Unlock()
301304
c.statusRecorder.ClientTeardown()
302305

303306
backOff.Reset()
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
382385
}
383386

384387
func (c *ConnectClient) Stop() error {
385-
if c == nil {
386-
return nil
387-
}
388-
c.engineMutex.Lock()
389-
defer c.engineMutex.Unlock()
390-
391-
if c.engine == nil {
392-
return nil
393-
}
394-
if err := c.engine.Stop(); err != nil {
395-
return fmt.Errorf("stop engine: %w", err)
388+
engine := c.Engine()
389+
if engine != nil {
390+
if err := engine.Stop(); err != nil {
391+
return fmt.Errorf("stop engine: %w", err)
392+
}
396393
}
397-
398394
return nil
399395
}
400396

client/internal/dns/server.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
6565

6666
// DefaultServer dns server object
6767
type DefaultServer struct {
68-
ctx context.Context
69-
ctxCancel context.CancelFunc
68+
ctx context.Context
69+
ctxCancel context.CancelFunc
70+
shutdownWg sync.WaitGroup
7071
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
7172
// This is different from ServiceEnable=false from management which completely disables the DNS service.
7273
disableSys bool
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
318319
// Stop stops the server
319320
func (s *DefaultServer) Stop() {
320321
s.ctxCancel()
322+
s.shutdownWg.Wait()
321323

322324
s.mux.Lock()
323325
defer s.mux.Unlock()
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
507509

508510
s.applyHostConfig()
509511

512+
s.shutdownWg.Add(1)
510513
go func() {
511-
// persist dns state right away
514+
defer s.shutdownWg.Done()
512515
if err := s.stateManager.PersistState(s.ctx); err != nil {
513516
log.Errorf("Failed to persist dns state: %v", err)
514517
}

client/internal/engine.go

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ type Engine struct {
148148

149149
// syncMsgMux is used to guarantee sequential Management Service message processing
150150
syncMsgMux *sync.Mutex
151+
// sshMux protects sshServer field access
152+
sshMux sync.Mutex
151153

152154
config *EngineConfig
153155
mobileDep MobileDependency
@@ -200,8 +202,10 @@ type Engine struct {
200202
flowManager nftypes.FlowManager
201203

202204
// WireGuard interface monitor
203-
wgIfaceMonitor *WGIfaceMonitor
204-
wgIfaceMonitorWg sync.WaitGroup
205+
wgIfaceMonitor *WGIfaceMonitor
206+
207+
// shutdownWg tracks the long-running goroutines that Stop waits for (with a timeout) during shutdown
208+
shutdownWg sync.WaitGroup
205209

206210
probeStunTurn *relay.StunTurnProbe
207211
}
@@ -320,19 +324,13 @@ func (e *Engine) Stop() error {
320324
e.cancel()
321325
}
322326

323-
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
324-
// Removing peers happens in the conn.Close() asynchronously
325-
time.Sleep(500 * time.Millisecond)
326-
327327
e.close()
328328

329329
// stop flow manager after wg interface is gone
330330
if e.flowManager != nil {
331331
e.flowManager.Close()
332332
}
333333

334-
log.Infof("stopped Netbird Engine")
335-
336334
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
337335
defer cancel()
338336

@@ -343,12 +341,52 @@ func (e *Engine) Stop() error {
343341
log.Errorf("failed to persist state: %v", err)
344342
}
345343

346-
// Stop WireGuard interface monitor and wait for it to exit
347-
e.wgIfaceMonitorWg.Wait()
344+
timeout := e.calculateShutdownTimeout()
345+
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
346+
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
347+
defer cancel()
348+
349+
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
350+
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
351+
}
352+
353+
log.Infof("stopped Netbird Engine")
348354

349355
return nil
350356
}
351357

358+
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
// The peer count is the number of keys currently returned by e.peerStore.PeersPubKey(),
// so the result reflects the peer store contents at the moment of the call.
359+
func (e *Engine) calculateShutdownTimeout() time.Duration {
360+
peerCount := len(e.peerStore.PeersPubKey())
361+
362+
baseTimeout := 10 * time.Second
363+
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
364+
timeout := baseTimeout + perPeerTimeout
365+
366+
// Upper bound so a large peer count cannot make the shutdown wait grow without limit.
maxTimeout := 30 * time.Second
367+
if timeout > maxTimeout {
368+
timeout = maxTimeout
369+
}
370+
371+
return timeout
372+
}
373+
374+
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
// It returns nil once wg.Wait() has returned. If ctx is done first, the helper
// goroutine started here is not stopped: it stays blocked in wg.Wait() until the
// WaitGroup completes, and only then closes the done channel.
375+
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
376+
// done is closed by the helper goroutine after wg.Wait() returns.
done := make(chan struct{})
377+
go func() {
378+
wg.Wait()
379+
close(done)
380+
}()
381+
382+
select {
383+
case <-done:
384+
return nil
385+
case <-ctx.Done():
386+
return ctx.Err()
387+
}
388+
}
389+
352390
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
353391
// Connections to remote peers are not established here.
354392
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -478,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
478516

479517
// monitor WireGuard interface lifecycle and restart engine on changes
480518
e.wgIfaceMonitor = NewWGIfaceMonitor()
481-
e.wgIfaceMonitorWg.Add(1)
519+
e.shutdownWg.Add(1)
482520

483521
go func() {
484-
defer e.wgIfaceMonitorWg.Done()
522+
defer e.shutdownWg.Done()
485523

486524
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
487525
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
488-
e.restartEngine()
526+
e.triggerClientRestart()
489527
} else if err != nil {
490528
log.Warnf("WireGuard interface monitor: %s", err)
491529
}
@@ -669,9 +707,11 @@ func (e *Engine) removeAllPeers() error {
669707
func (e *Engine) removePeer(peerKey string) error {
670708
log.Debugf("removing peer from engine %s", peerKey)
671709

710+
e.sshMux.Lock()
672711
if !isNil(e.sshServer) {
673712
e.sshServer.RemoveAuthorizedKey(peerKey)
674713
}
714+
e.sshMux.Unlock()
675715

676716
e.connMgr.RemovePeerConn(peerKey)
677717

@@ -873,41 +913,50 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
873913
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
874914
return nil
875915
}
916+
e.sshMux.Lock()
876917
// start SSH server if it wasn't running
877918
if isNil(e.sshServer) {
878919
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
879920
if nbnetstack.IsEnabled() {
880921
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
881922
}
882923
// nil sshServer means it has not yet been started
883-
var err error
884-
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
885-
924+
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
886925
if err != nil {
926+
e.sshMux.Unlock()
887927
return fmt.Errorf("create ssh server: %w", err)
888928
}
929+
930+
e.sshServer = server
931+
e.sshMux.Unlock()
932+
889933
go func() {
890934
// blocking
891-
err = e.sshServer.Start()
935+
err = server.Start()
892936
if err != nil {
893937
// will throw error when we stop it even if it is a graceful stop
894938
log.Debugf("stopped SSH server with error %v", err)
895939
}
896-
e.syncMsgMux.Lock()
897-
defer e.syncMsgMux.Unlock()
940+
e.sshMux.Lock()
898941
e.sshServer = nil
942+
e.sshMux.Unlock()
899943
log.Infof("stopped SSH server")
900944
}()
901945
} else {
946+
e.sshMux.Unlock()
902947
log.Debugf("SSH server is already running")
903948
}
904-
} else if !isNil(e.sshServer) {
905-
// Disable SSH server request, so stop it if it was running
906-
err := e.sshServer.Stop()
907-
if err != nil {
908-
log.Warnf("failed to stop SSH server %v", err)
949+
} else {
950+
e.sshMux.Lock()
951+
if !isNil(e.sshServer) {
952+
// Disable SSH server request, so stop it if it was running
953+
err := e.sshServer.Stop()
954+
if err != nil {
955+
log.Warnf("failed to stop SSH server %v", err)
956+
}
957+
e.sshServer = nil
909958
}
910-
e.sshServer = nil
959+
e.sshMux.Unlock()
911960
}
912961
return nil
913962
}
@@ -944,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
944993
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
945994
// E.g. when a new peer has been registered and we are allowed to connect to it.
946995
func (e *Engine) receiveManagementEvents() {
996+
e.shutdownWg.Add(1)
947997
go func() {
998+
defer e.shutdownWg.Done()
948999
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
9491000
if err != nil {
9501001
log.Warnf("failed to get system info with checks: %v", err)
@@ -1120,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
11201171
e.statusRecorder.FinishPeerListModifications()
11211172

11221173
// update SSHServer by adding remote peer SSH keys
1174+
e.sshMux.Lock()
11231175
if !isNil(e.sshServer) {
11241176
for _, config := range networkMap.GetRemotePeers() {
11251177
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
@@ -1130,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
11301182
}
11311183
}
11321184
}
1185+
e.sshMux.Unlock()
11331186
}
11341187

11351188
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1372,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
13721425

13731426
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
13741427
func (e *Engine) receiveSignalEvents() {
1428+
e.shutdownWg.Add(1)
13751429
go func() {
1430+
defer e.shutdownWg.Done()
13761431
// connect to a stream of messages coming from the signal server
13771432
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
13781433
e.syncMsgMux.Lock()
@@ -1489,12 +1544,14 @@ func (e *Engine) close() {
14891544
e.statusRecorder.SetWgIface(nil)
14901545
}
14911546

1547+
e.sshMux.Lock()
14921548
if !isNil(e.sshServer) {
14931549
err := e.sshServer.Stop()
14941550
if err != nil {
14951551
log.Warnf("failed stopping the SSH server: %v", err)
14961552
}
14971553
}
1554+
e.sshMux.Unlock()
14981555

14991556
if e.firewall != nil {
15001557
err := e.firewall.Close(e.stateManager)
@@ -1725,8 +1782,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
17251782
return allHealthy
17261783
}
17271784

1728-
// restartEngine restarts the engine by cancelling the client context
1729-
func (e *Engine) restartEngine() {
1785+
// triggerClientRestart triggers a full client restart by cancelling the client context.
1786+
// Note: This does NOT just restart the engine - it cancels the entire client context,
1787+
// which causes the connect client's retry loop to create a completely new engine.
1788+
func (e *Engine) triggerClientRestart() {
17301789
e.syncMsgMux.Lock()
17311790
defer e.syncMsgMux.Unlock()
17321791

@@ -1748,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() {
17481807
}
17491808

17501809
e.networkMonitor = networkmonitor.New()
1810+
e.shutdownWg.Add(1)
17511811
go func() {
1812+
defer e.shutdownWg.Done()
17521813
if err := e.networkMonitor.Listen(e.ctx); err != nil {
17531814
if errors.Is(err, context.Canceled) {
17541815
log.Infof("network monitor stopped")
@@ -1758,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() {
17581819
return
17591820
}
17601821

1761-
log.Infof("Network monitor: detected network change, restarting engine")
1762-
e.restartEngine()
1822+
log.Infof("Network monitor: detected network change, triggering client restart")
1823+
e.triggerClientRestart()
17631824
}()
17641825
}
17651826

0 commit comments

Comments
 (0)