diff --git a/internal/bus/topics.go b/internal/bus/topics.go index f342006d2..7ba30ac49 100644 --- a/internal/bus/topics.go +++ b/internal/bus/topics.go @@ -15,6 +15,8 @@ const ( ConfigUploadRequestTopic = "config-upload-request" DataPlaneResponseTopic = "data-plane-response" ConnectionCreatedTopic = "connection-created" + CredentialUpdatedTopic = "credential-updated" + ConnectionResetTopic = "connection-reset" ConfigApplyRequestTopic = "config-apply-request" WriteConfigSuccessfulTopic = "write-config-successful" ConfigApplySuccessfulTopic = "config-apply-successful" diff --git a/internal/command/command_plugin.go b/internal/command/command_plugin.go index 293adfe14..57c46b048 100644 --- a/internal/command/command_plugin.go +++ b/internal/command/command_plugin.go @@ -31,6 +31,7 @@ type ( UpdateDataPlaneStatus(ctx context.Context, resource *mpi.Resource) error UpdateDataPlaneHealth(ctx context.Context, instanceHealths []*mpi.InstanceHealth) error SendDataPlaneResponse(ctx context.Context, response *mpi.DataPlaneResponse) error + UpdateClient(client mpi.CommandServiceClient) Subscribe(ctx context.Context) IsConnected() bool CreateConnection(ctx context.Context, resource *mpi.Resource) (*mpi.CreateConnectionResponse, error) @@ -86,6 +87,8 @@ func (cp *CommandPlugin) Info() *bus.Info { func (cp *CommandPlugin) Process(ctx context.Context, msg *bus.Message) { switch msg.Topic { + case bus.ConnectionResetTopic: + cp.processConnectionReset(ctx, msg) case bus.ResourceUpdateTopic: cp.processResourceUpdate(ctx, msg) case bus.InstanceHealthTopic: @@ -172,8 +175,22 @@ func (cp *CommandPlugin) processDataPlaneResponse(ctx context.Context, msg *bus. 
} } +func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Message) { + slog.DebugContext(ctx, "Command plugin received connection reset") + if newConnection, ok := msg.Data.(grpc.GrpcConnectionInterface); ok { + err := cp.conn.Close(ctx) + if err != nil { + slog.ErrorContext(ctx, "Command plugin: unable to close connection", "error", err) + } + cp.conn = newConnection + cp.commandService.UpdateClient(cp.conn.CommandServiceClient()) + slog.DebugContext(ctx, "Command service client reset successfully") + } +} + func (cp *CommandPlugin) Subscriptions() []string { return []string{ + bus.ConnectionResetTopic, bus.ResourceUpdateTopic, bus.InstanceHealthTopic, bus.DataPlaneHealthResponseTopic, diff --git a/internal/command/command_plugin_test.go b/internal/command/command_plugin_test.go index 90f56c394..2c532bf65 100644 --- a/internal/command/command_plugin_test.go +++ b/internal/command/command_plugin_test.go @@ -45,6 +45,7 @@ func TestCommandPlugin_Subscriptions(t *testing.T) { assert.Equal( t, []string{ + bus.ConnectionResetTopic, bus.ResourceUpdateTopic, bus.InstanceHealthTopic, bus.DataPlaneHealthResponseTopic, @@ -142,6 +143,12 @@ func TestCommandPlugin_Process(t *testing.T) { }) require.Equal(t, 1, fakeCommandService.UpdateDataPlaneHealthCallCount()) require.Equal(t, 1, fakeCommandService.SendDataPlaneResponseCallCount()) + + commandPlugin.Process(ctx, &bus.Message{ + Topic: bus.ConnectionResetTopic, + Data: commandPlugin.conn, + }) + require.Equal(t, 1, fakeCommandService.UpdateClientCallCount()) } func TestCommandPlugin_monitorSubscribeChannel(t *testing.T) { diff --git a/internal/command/command_service.go b/internal/command/command_service.go index 3c569e7f8..6b1787401 100644 --- a/internal/command/command_service.go +++ b/internal/command/command_service.go @@ -98,11 +98,13 @@ func (cs *CommandService) UpdateDataPlaneStatus( sendDataPlaneStatus := func() (*mpi.UpdateDataPlaneStatusResponse, error) { slog.DebugContext(ctx, "Sending data 
plane status update request", "request", request, "parent_correlation_id", correlationID)
+
+		cs.subscribeClientMutex.Lock()
 		if cs.commandServiceClient == nil {
+			cs.subscribeClientMutex.Unlock()
 			return nil, errors.New("command service client is not initialized")
 		}
 		response, updateError := cs.commandServiceClient.UpdateDataPlaneStatus(ctx, request)
+		cs.subscribeClientMutex.Unlock()
 
 		validatedError := grpc.ValidateGrpcError(updateError)
 
 		if validatedError != nil {
@@ -210,6 +212,10 @@ func (cs *CommandService) CreateConnection(
 		slog.InfoContext(ctx, "No Data Plane Instance found")
 	}
 
+	if cs.isConnected.Load() {
+		return nil, errors.New("command service already connected")
+	}
+
 	request := &mpi.CreateConnectionRequest{
 		MessageMeta: &mpi.MessageMeta{
 			MessageId: id.GenerateMessageID(),
@@ -228,7 +234,6 @@ func (cs *CommandService) CreateConnection(
 	}
 
 	slog.DebugContext(ctx, "Sending create connection request", "request", request)
-
 	response, err := backoff.RetryWithData(
 		cs.connectCallback(ctx, request),
 		backoffHelpers.Context(ctx, commonSettings),
@@ -249,6 +254,12 @@ func (cs *CommandService) CreateConnection(
 	return response, nil
 }
 
+func (cs *CommandService) UpdateClient(client mpi.CommandServiceClient) {
+	cs.subscribeClientMutex.Lock()
+	defer cs.subscribeClientMutex.Unlock()
+	cs.commandServiceClient = client
+}
+
 // Retry callback for sending a data plane response to the Management Plane.
func (cs *CommandService) sendDataPlaneResponseCallback(
 	ctx context.Context,
@@ -355,11 +366,15 @@ func (cs *CommandService) dataPlaneHealthCallback(
 ) func() (*mpi.UpdateDataPlaneHealthResponse, error) {
 	return func() (*mpi.UpdateDataPlaneHealthResponse, error) {
 		slog.DebugContext(ctx, "Sending data plane health update request", "request", request)
+
+		cs.subscribeClientMutex.Lock()
 		if cs.commandServiceClient == nil {
+			cs.subscribeClientMutex.Unlock()
 			return nil, errors.New("command service client is not initialized")
 		}
 
 		response, updateError := cs.commandServiceClient.UpdateDataPlaneHealth(ctx, request)
+		cs.subscribeClientMutex.Unlock()
 
 		validatedError := grpc.ValidateGrpcError(updateError)
 
@@ -427,6 +441,7 @@ func (cs *CommandService) handleSubscribeError(ctx context.Context, err error, e
 	codeError, ok := status.FromError(err)
 	if ok && codeError.Code() == codes.Unavailable {
+		cs.isConnected.Store(false)
 		slog.ErrorContext(ctx, fmt.Sprintf("Failed to %s, rpc unavailable. "+
 			"Trying create connection rpc", errorMsg), "error", err)
 		_, connectionErr := cs.CreateConnection(ctx, cs.resource)
@@ -530,7 +545,9 @@ func (cs *CommandService) connectCallback(
 	request *mpi.CreateConnectionRequest,
 ) func() (*mpi.CreateConnectionResponse, error) {
 	return func() (*mpi.CreateConnectionResponse, error) {
+		cs.subscribeClientMutex.Lock()
 		response, connectErr := cs.commandServiceClient.CreateConnection(ctx, request)
+		cs.subscribeClientMutex.Unlock()
 		validatedError := grpc.ValidateGrpcError(connectErr)
 
 		if validatedError != nil {
diff --git a/internal/command/command_service_test.go b/internal/command/command_service_test.go
index 0dfb00e95..936e523d2 100644
--- a/internal/command/command_service_test.go
+++ b/internal/command/command_service_test.go
@@ -198,6 +198,18 @@ func TestCommandService_CreateConnection(t *testing.T) {
 	require.NoError(t, err)
 }
 
+func TestCommandService_UpdateClient(t *testing.T) {
+	commandServiceClient := &v1fakes.FakeCommandServiceClient{}
+
+	commandService := NewCommandService(
commandServiceClient, + types.AgentConfig(), + make(chan *mpi.ManagementPlaneRequest), + ) + commandService.UpdateClient(commandServiceClient) + assert.NotNil(t, commandService.commandServiceClient) +} + func TestCommandService_UpdateDataPlaneHealth(t *testing.T) { ctx := context.Background() commandServiceClient := &v1fakes.FakeCommandServiceClient{} @@ -501,3 +513,18 @@ func TestCommandService_isValidRequest(t *testing.T) { }) } } + +func TestCommandService_handleSubscribeError(t *testing.T) { + ctx := context.Background() + commandServiceClient := &v1fakes.FakeCommandServiceClient{} + + commandService := NewCommandService( + commandServiceClient, + types.AgentConfig(), + make(chan *mpi.ManagementPlaneRequest), + ) + require.Error(t, + commandService.handleSubscribeError(ctx, + errors.New("an error occurred when attempting to subscribe"), + "Testing handleSubscribeError")) +} diff --git a/internal/command/commandfakes/fake_command_service.go b/internal/command/commandfakes/fake_command_service.go index 0748ce080..7bfaeb7c0 100644 --- a/internal/command/commandfakes/fake_command_service.go +++ b/internal/command/commandfakes/fake_command_service.go @@ -50,6 +50,11 @@ type FakeCommandService struct { subscribeArgsForCall []struct { arg1 context.Context } + UpdateClientStub func(v1.CommandServiceClient) + updateClientMutex sync.RWMutex + updateClientArgsForCall []struct { + arg1 v1.CommandServiceClient + } UpdateDataPlaneHealthStub func(context.Context, []*v1.InstanceHealth) error updateDataPlaneHealthMutex sync.RWMutex updateDataPlaneHealthArgsForCall []struct { @@ -290,6 +295,38 @@ func (fake *FakeCommandService) SubscribeArgsForCall(i int) context.Context { return argsForCall.arg1 } +func (fake *FakeCommandService) UpdateClient(arg1 v1.CommandServiceClient) { + fake.updateClientMutex.Lock() + fake.updateClientArgsForCall = append(fake.updateClientArgsForCall, struct { + arg1 v1.CommandServiceClient + }{arg1}) + stub := fake.UpdateClientStub + 
fake.recordInvocation("UpdateClient", []interface{}{arg1}) + fake.updateClientMutex.Unlock() + if stub != nil { + fake.UpdateClientStub(arg1) + } +} + +func (fake *FakeCommandService) UpdateClientCallCount() int { + fake.updateClientMutex.RLock() + defer fake.updateClientMutex.RUnlock() + return len(fake.updateClientArgsForCall) +} + +func (fake *FakeCommandService) UpdateClientCalls(stub func(v1.CommandServiceClient)) { + fake.updateClientMutex.Lock() + defer fake.updateClientMutex.Unlock() + fake.UpdateClientStub = stub +} + +func (fake *FakeCommandService) UpdateClientArgsForCall(i int) v1.CommandServiceClient { + fake.updateClientMutex.RLock() + defer fake.updateClientMutex.RUnlock() + argsForCall := fake.updateClientArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeCommandService) UpdateDataPlaneHealth(arg1 context.Context, arg2 []*v1.InstanceHealth) error { var arg2Copy []*v1.InstanceHealth if arg2 != nil { @@ -430,6 +467,8 @@ func (fake *FakeCommandService) Invocations() map[string][][]interface{} { defer fake.sendDataPlaneResponseMutex.RUnlock() fake.subscribeMutex.RLock() defer fake.subscribeMutex.RUnlock() + fake.updateClientMutex.RLock() + defer fake.updateClientMutex.RUnlock() fake.updateDataPlaneHealthMutex.RLock() defer fake.updateDataPlaneHealthMutex.RUnlock() fake.updateDataPlaneStatusMutex.RLock() diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4bf989646..5119e7415 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -901,7 +901,8 @@ func createConfig() *Config { Type: Grpc, }, Auth: &AuthConfig{ - Token: "1234", + Token: "1234", + TokenPath: "path/to/my_token", }, TLS: &TLSConfig{ Cert: "some.cert", diff --git a/internal/config/testdata/nginx-agent.conf b/internal/config/testdata/nginx-agent.conf index f138afac9..d9bce7f50 100644 --- a/internal/config/testdata/nginx-agent.conf +++ b/internal/config/testdata/nginx-agent.conf @@ -54,6 +54,7 @@ command: type: grpc auth: 
token: "1234" + tokenpath: "path/to/my_token" tls: cert: "some.cert" key: "some.key" diff --git a/internal/file/file_manager_service.go b/internal/file/file_manager_service.go index c107bbe1c..cda8353e5 100644 --- a/internal/file/file_manager_service.go +++ b/internal/file/file_manager_service.go @@ -54,6 +54,7 @@ type ( UpdateCurrentFilesOnDisk(updateFiles map[string]*mpi.File) DetermineFileActions(currentFiles, modifiedFiles map[string]*mpi.File) (map[string]*mpi.File, map[string][]byte, error) + IsConnected() bool SetIsConnected(isConnected bool) } ) @@ -271,6 +272,10 @@ func (fms *FileManagerService) UpdateFile( return err } +func (fms *FileManagerService) IsConnected() bool { + return fms.isConnected.Load() +} + func (fms *FileManagerService) SetIsConnected(isConnected bool) { fms.isConnected.Store(isConnected) } diff --git a/internal/file/file_plugin.go b/internal/file/file_plugin.go index 28a24c6c2..212b85894 100644 --- a/internal/file/file_plugin.go +++ b/internal/file/file_plugin.go @@ -62,6 +62,8 @@ func (fp *FilePlugin) Info() *bus.Info { func (fp *FilePlugin) Process(ctx context.Context, msg *bus.Message) { switch msg.Topic { + case bus.ConnectionResetTopic: + fp.handleConnectionReset(ctx, msg) case bus.ConnectionCreatedTopic: fp.fileManagerService.SetIsConnected(true) case bus.NginxConfigUpdateTopic: @@ -81,6 +83,7 @@ func (fp *FilePlugin) Process(ctx context.Context, msg *bus.Message) { func (fp *FilePlugin) Subscriptions() []string { return []string{ + bus.ConnectionResetTopic, bus.ConnectionCreatedTopic, bus.NginxConfigUpdateTopic, bus.ConfigUploadRequestTopic, @@ -91,6 +94,24 @@ func (fp *FilePlugin) Subscriptions() []string { } } +func (fp *FilePlugin) handleConnectionReset(ctx context.Context, msg *bus.Message) { + slog.DebugContext(ctx, "File plugin received connection reset message") + if newConnection, ok := msg.Data.(grpc.GrpcConnectionInterface); ok { + var reconnect bool + err := fp.conn.Close(ctx) + if err != nil { + slog.ErrorContext(ctx, 
"File plugin: unable to close connection", "error", err) + } + fp.conn = newConnection + + reconnect = fp.fileManagerService.IsConnected() + fp.fileManagerService = NewFileManagerService(fp.conn.FileServiceClient(), fp.config) + fp.fileManagerService.SetIsConnected(reconnect) + + slog.DebugContext(ctx, "File plugin: client reset successfully") + } +} + func (fp *FilePlugin) handleConfigApplyComplete(ctx context.Context, msg *bus.Message) { response, ok := msg.Data.(*mpi.DataPlaneResponse) diff --git a/internal/file/file_plugin_test.go b/internal/file/file_plugin_test.go index 8ff0a07b7..d550d8b3b 100644 --- a/internal/file/file_plugin_test.go +++ b/internal/file/file_plugin_test.go @@ -50,6 +50,7 @@ func TestFilePlugin_Subscriptions(t *testing.T) { assert.Equal( t, []string{ + bus.ConnectionResetTopic, bus.ConnectionCreatedTopic, bus.NginxConfigUpdateTopic, bus.ConfigUploadRequestTopic, diff --git a/internal/file/filefakes/fake_file_manager_service_interface.go b/internal/file/filefakes/fake_file_manager_service_interface.go index 4819613ca..588318877 100644 --- a/internal/file/filefakes/fake_file_manager_service_interface.go +++ b/internal/file/filefakes/fake_file_manager_service_interface.go @@ -44,6 +44,16 @@ type FakeFileManagerServiceInterface struct { result2 map[string][]byte result3 error } + IsConnectedStub func() bool + isConnectedMutex sync.RWMutex + isConnectedArgsForCall []struct { + } + isConnectedReturns struct { + result1 bool + } + isConnectedReturnsOnCall map[int]struct { + result1 bool + } RollbackStub func(context.Context, string) error rollbackMutex sync.RWMutex rollbackArgsForCall []struct { @@ -254,6 +264,59 @@ func (fake *FakeFileManagerServiceInterface) DetermineFileActionsReturnsOnCall(i }{result1, result2, result3} } +func (fake *FakeFileManagerServiceInterface) IsConnected() bool { + fake.isConnectedMutex.Lock() + ret, specificReturn := fake.isConnectedReturnsOnCall[len(fake.isConnectedArgsForCall)] + fake.isConnectedArgsForCall = 
append(fake.isConnectedArgsForCall, struct { + }{}) + stub := fake.IsConnectedStub + fakeReturns := fake.isConnectedReturns + fake.recordInvocation("IsConnected", []interface{}{}) + fake.isConnectedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeFileManagerServiceInterface) IsConnectedCallCount() int { + fake.isConnectedMutex.RLock() + defer fake.isConnectedMutex.RUnlock() + return len(fake.isConnectedArgsForCall) +} + +func (fake *FakeFileManagerServiceInterface) IsConnectedCalls(stub func() bool) { + fake.isConnectedMutex.Lock() + defer fake.isConnectedMutex.Unlock() + fake.IsConnectedStub = stub +} + +func (fake *FakeFileManagerServiceInterface) IsConnectedReturns(result1 bool) { + fake.isConnectedMutex.Lock() + defer fake.isConnectedMutex.Unlock() + fake.IsConnectedStub = nil + fake.isConnectedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeFileManagerServiceInterface) IsConnectedReturnsOnCall(i int, result1 bool) { + fake.isConnectedMutex.Lock() + defer fake.isConnectedMutex.Unlock() + fake.IsConnectedStub = nil + if fake.isConnectedReturnsOnCall == nil { + fake.isConnectedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isConnectedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeFileManagerServiceInterface) Rollback(arg1 context.Context, arg2 string) error { fake.rollbackMutex.Lock() ret, specificReturn := fake.rollbackReturnsOnCall[len(fake.rollbackArgsForCall)] @@ -521,6 +584,8 @@ func (fake *FakeFileManagerServiceInterface) Invocations() map[string][][]interf defer fake.configApplyMutex.RUnlock() fake.determineFileActionsMutex.RLock() defer fake.determineFileActionsMutex.RUnlock() + fake.isConnectedMutex.RLock() + defer fake.isConnectedMutex.RUnlock() fake.rollbackMutex.RLock() defer fake.rollbackMutex.RUnlock() fake.setIsConnectedMutex.RLock() diff --git a/internal/grpc/grpc.go 
b/internal/grpc/grpc.go index ebdf8693e..550dcc7b1 100644 --- a/internal/grpc/grpc.go +++ b/internal/grpc/grpc.go @@ -41,7 +41,6 @@ type ( CommandServiceClient() mpi.CommandServiceClient FileServiceClient() mpi.FileServiceClient Close(ctx context.Context) error - Restart(ctx context.Context) (*GrpcConnection, error) } GrpcConnection struct { @@ -68,6 +67,7 @@ var ( _ GrpcConnectionInterface = (*GrpcConnection)(nil) ) +// nolint: ireturn func NewGrpcConnection(ctx context.Context, agentConfig *config.Config) (*GrpcConnection, error) { if agentConfig == nil || agentConfig.Command.Server.Type != config.Grpc { return nil, errors.New("invalid command server settings") @@ -131,22 +131,6 @@ func (gc *GrpcConnection) Close(ctx context.Context) error { return nil } -func (gc *GrpcConnection) Restart(ctx context.Context) (*GrpcConnection, error) { - slog.InfoContext(ctx, "Restarting grpc connection") - err := gc.Close(ctx) - if err != nil { - return nil, err - } - - slog.InfoContext(ctx, "Creating grpc connection") - newConn, err := NewGrpcConnection(ctx, gc.config) - if err != nil { - return nil, err - } - - return newConn, nil -} - func (w *wrappedStream) RecvMsg(message any) error { err := w.ClientStream.RecvMsg(message) if err == nil { diff --git a/internal/grpc/grpcfakes/fake_grpc_connection_interface.go b/internal/grpc/grpcfakes/fake_grpc_connection_interface.go index 6d487dbd7..6b2263cd7 100644 --- a/internal/grpc/grpcfakes/fake_grpc_connection_interface.go +++ b/internal/grpc/grpcfakes/fake_grpc_connection_interface.go @@ -41,19 +41,6 @@ type FakeGrpcConnectionInterface struct { fileServiceClientReturnsOnCall map[int]struct { result1 v1.FileServiceClient } - RestartStub func(context.Context) (*grpc.GrpcConnection, error) - restartMutex sync.RWMutex - restartArgsForCall []struct { - arg1 context.Context - } - restartReturns struct { - result1 *grpc.GrpcConnection - result2 error - } - restartReturnsOnCall map[int]struct { - result1 *grpc.GrpcConnection - result2 error 
- } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -225,70 +212,6 @@ func (fake *FakeGrpcConnectionInterface) FileServiceClientReturnsOnCall(i int, r }{result1} } -func (fake *FakeGrpcConnectionInterface) Restart(arg1 context.Context) (*grpc.GrpcConnection, error) { - fake.restartMutex.Lock() - ret, specificReturn := fake.restartReturnsOnCall[len(fake.restartArgsForCall)] - fake.restartArgsForCall = append(fake.restartArgsForCall, struct { - arg1 context.Context - }{arg1}) - stub := fake.RestartStub - fakeReturns := fake.restartReturns - fake.recordInvocation("Restart", []interface{}{arg1}) - fake.restartMutex.Unlock() - if stub != nil { - return stub(arg1) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeGrpcConnectionInterface) RestartCallCount() int { - fake.restartMutex.RLock() - defer fake.restartMutex.RUnlock() - return len(fake.restartArgsForCall) -} - -func (fake *FakeGrpcConnectionInterface) RestartCalls(stub func(context.Context) (*grpc.GrpcConnection, error)) { - fake.restartMutex.Lock() - defer fake.restartMutex.Unlock() - fake.RestartStub = stub -} - -func (fake *FakeGrpcConnectionInterface) RestartArgsForCall(i int) context.Context { - fake.restartMutex.RLock() - defer fake.restartMutex.RUnlock() - argsForCall := fake.restartArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *FakeGrpcConnectionInterface) RestartReturns(result1 *grpc.GrpcConnection, result2 error) { - fake.restartMutex.Lock() - defer fake.restartMutex.Unlock() - fake.RestartStub = nil - fake.restartReturns = struct { - result1 *grpc.GrpcConnection - result2 error - }{result1, result2} -} - -func (fake *FakeGrpcConnectionInterface) RestartReturnsOnCall(i int, result1 *grpc.GrpcConnection, result2 error) { - fake.restartMutex.Lock() - defer fake.restartMutex.Unlock() - fake.RestartStub = nil - if fake.restartReturnsOnCall == nil { - fake.restartReturnsOnCall = 
make(map[int]struct { - result1 *grpc.GrpcConnection - result2 error - }) - } - fake.restartReturnsOnCall[i] = struct { - result1 *grpc.GrpcConnection - result2 error - }{result1, result2} -} - func (fake *FakeGrpcConnectionInterface) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -298,8 +221,6 @@ func (fake *FakeGrpcConnectionInterface) Invocations() map[string][][]interface{ defer fake.commandServiceClientMutex.RUnlock() fake.fileServiceClientMutex.RLock() defer fake.fileServiceClientMutex.RUnlock() - fake.restartMutex.RLock() - defer fake.restartMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/internal/watcher/credentials/credential_watcher_service.go b/internal/watcher/credentials/credential_watcher_service.go new file mode 100644 index 000000000..440c029db --- /dev/null +++ b/internal/watcher/credentials/credential_watcher_service.go @@ -0,0 +1,194 @@ +// Copyright (c) F5, Inc. +// +// This source code is licensed under the Apache License, Version 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+ +package credentials + +import ( + "context" + "log/slog" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/nginx/agent/v3/internal/config" + "github.com/nginx/agent/v3/internal/logger" +) + +const ( + monitoringInterval = 5 * time.Second +) + +var emptyEvent = fsnotify.Event{ + Name: "", + Op: 0, +} + +type CredentialUpdateMessage struct { + CorrelationID slog.Attr +} + +type CredentialWatcherService struct { + enabled *atomic.Bool + agentConfig *config.Config + watcher *fsnotify.Watcher + filesBeingWatched *sync.Map + filesChanged *atomic.Bool +} + +func NewCredentialWatcherService(agentConfig *config.Config) *CredentialWatcherService { + enabled := &atomic.Bool{} + enabled.Store(true) + + filesChanged := &atomic.Bool{} + filesChanged.Store(false) + + return &CredentialWatcherService{ + enabled: enabled, + agentConfig: agentConfig, + filesBeingWatched: &sync.Map{}, + filesChanged: filesChanged, + } +} + +func (cws *CredentialWatcherService) Watch(ctx context.Context, ch chan<- CredentialUpdateMessage) { + slog.DebugContext(ctx, "Starting credential watcher monitoring") + + ticker := time.NewTicker(monitoringInterval) + watcher, err := fsnotify.NewWatcher() + if err != nil { + slog.ErrorContext(ctx, "Failed to create credential watcher", "error", err) + return + } + + cws.watcher = watcher + + cws.watchFiles(ctx, credentialPaths(cws.agentConfig)) + + for { + select { + case <-ctx.Done(): + closeError := cws.watcher.Close() + if closeError != nil { + slog.ErrorContext(ctx, "Unable to close credential watcher", "error", closeError) + } + + return + case event := <-cws.watcher.Events: + cws.handleEvent(ctx, event) + case <-ticker.C: + cws.checkForUpdates(ctx, ch) + case watcherError := <-cws.watcher.Errors: + slog.ErrorContext(ctx, "Unexpected error in credential watcher", "error", watcherError) + } + } +} + +func (cws *CredentialWatcherService) SetEnabled(enabled bool) { + cws.enabled.Store(enabled) +} + +func (cws 
*CredentialWatcherService) addWatcher(ctx context.Context, filePath string) { + if !cws.enabled.Load() { + slog.DebugContext(ctx, "Credential watcher is disabled") + + return + } + + if cws.isWatching(filePath) { + slog.DebugContext( + ctx, "Credential watcher is already watching ", "path", filePath) + + return + } + + if err := cws.watcher.Add(filePath); err != nil { + slog.ErrorContext(ctx, "Failed to add credential watcher", "path", filePath, "error", err) + removeError := cws.watcher.Remove(filePath) + if removeError != nil { + slog.ErrorContext( + ctx, "Failed to remove credential watcher", "path", filePath, "error", removeError) + } + + return + } + cws.filesBeingWatched.Store(filePath, true) + slog.DebugContext(ctx, "Credential watcher has been added", "path", filePath) +} + +func (cws *CredentialWatcherService) watchFiles(ctx context.Context, files []string) { + slog.DebugContext(ctx, "Creating credential watchers") + + for _, filePath := range files { + cws.addWatcher(ctx, filePath) + } +} + +func (cws *CredentialWatcherService) isWatching(path string) bool { + v, _ := cws.filesBeingWatched.LoadOrStore(path, false) + + if value, ok := v.(bool); ok { + return value + } + + return false +} + +func (cws *CredentialWatcherService) handleEvent(ctx context.Context, event fsnotify.Event) { + if cws.enabled.Load() { + if isEventSkippable(event) { + slog.DebugContext(ctx, "Skipping FSNotify event", "event", event) + return + } + + slog.DebugContext(ctx, "Processing FSNotify event", "event", event) + + switch { + case event.Has(fsnotify.Remove): + fallthrough + case event.Has(fsnotify.Rename): + if !slices.Contains(cws.watcher.WatchList(), event.Name) { + cws.filesBeingWatched.Store(event.Name, false) + } + cws.addWatcher(ctx, event.Name) + } + + cws.filesChanged.Store(true) + } +} + +func (cws *CredentialWatcherService) checkForUpdates(ctx context.Context, ch chan<- CredentialUpdateMessage) { + if cws.filesChanged.Load() { + newCtx := context.WithValue( + ctx, + 
logger.CorrelationIDContextKey, + slog.Any(logger.CorrelationIDKey, logger.GenerateCorrelationID()), + ) + + slog.DebugContext(ctx, "Credential watcher has detected changes") + ch <- CredentialUpdateMessage{CorrelationID: logger.GetCorrelationIDAttr(newCtx)} + cws.filesChanged.Store(false) + } +} + +func credentialPaths(agentConfig *config.Config) []string { + var paths []string + + if agentConfig.Command.Auth != nil { + if agentConfig.Command.Auth.TokenPath != "" { + paths = append(paths, agentConfig.Command.Auth.TokenPath) + } + } + + return paths +} + +func isEventSkippable(event fsnotify.Event) bool { + return event == emptyEvent || + event.Name == "" || + event.Has(fsnotify.Chmod) || + event.Has(fsnotify.Create) +} diff --git a/internal/watcher/credentials/credential_watcher_service_test.go b/internal/watcher/credentials/credential_watcher_service_test.go new file mode 100644 index 000000000..23fd3b9e5 --- /dev/null +++ b/internal/watcher/credentials/credential_watcher_service_test.go @@ -0,0 +1,233 @@ +// Copyright (c) F5, Inc. +// +// This source code is licensed under the Apache License, Version 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+ +package credentials + +import ( + "context" + "fmt" + "os" + "path" + "testing" + "time" + + "github.com/nginx/agent/v3/internal/config" + + "github.com/fsnotify/fsnotify" + "github.com/nginx/agent/v3/test/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCredentialWatcherService_TestNewCredentialWatcherService(t *testing.T) { + credentialWatcherService := NewCredentialWatcherService(types.AgentConfig()) + + assert.Empty(t, credentialWatcherService.filesBeingWatched) + assert.True(t, credentialWatcherService.enabled.Load()) + assert.False(t, credentialWatcherService.filesChanged.Load()) +} + +func TestCredentialWatcherService_SetEnabled(t *testing.T) { + credentialWatcherService := NewCredentialWatcherService(types.AgentConfig()) + assert.True(t, credentialWatcherService.enabled.Load()) + + credentialWatcherService.SetEnabled(false) + assert.False(t, credentialWatcherService.enabled.Load()) + + credentialWatcherService.SetEnabled(true) + assert.True(t, credentialWatcherService.enabled.Load()) +} + +func TestCredentialWatcherService_Watch(t *testing.T) { + ctx := context.Background() + cws := NewCredentialWatcherService(types.AgentConfig()) + watcher, err := fsnotify.NewWatcher() + require.NoError(t, err) + cws.watcher = watcher + + cuc := make(chan CredentialUpdateMessage) + + name := path.Join(os.TempDir(), "test_file") + _, err = os.Create(name) + require.NoError(t, err) + defer os.Remove(name) + + cws.agentConfig.Command.Auth.TokenPath = name + cws.filesChanged.Store(true) + go cws.Watch(ctx, cuc) + + select { + case <-ctx.Done(): + t.Error("context done") + case <-cuc: + assert.True(t, cws.isWatching(name)) + case <-time.After(2 * monitoringInterval): + t.Error("Timed out waiting for credential watch") + } + + func() { + cws.watcher.Errors <- fmt.Errorf("watch error") + }() +} + +func TestCredentialWatcherService_isWatching(t *testing.T) { + cws := NewCredentialWatcherService(types.AgentConfig()) + 
assert.False(t, cws.isWatching("test-file")) + cws.filesBeingWatched.Store("test-file", true) + assert.True(t, cws.isWatching("test-file")) + cws.filesBeingWatched.Store("test-file", false) + assert.False(t, cws.isWatching("test-file")) +} + +func TestCredentialWatcherService_isEventSkippable(t *testing.T) { + assert.False(t, isEventSkippable(fsnotify.Event{Name: "testWriteEvent", Op: fsnotify.Write})) + assert.True(t, isEventSkippable(fsnotify.Event{Name: "", Op: 0})) + assert.True(t, isEventSkippable(fsnotify.Event{Name: "", Op: fsnotify.Write})) + assert.True(t, isEventSkippable(fsnotify.Event{Op: fsnotify.Chmod})) + assert.True(t, isEventSkippable(fsnotify.Event{Op: fsnotify.Rename})) + assert.True(t, isEventSkippable(fsnotify.Event{Op: fsnotify.Create})) +} + +func TestCredentialWatcherService_addWatcher(t *testing.T) { + ctx := context.Background() + cws := NewCredentialWatcherService(types.AgentConfig()) + watcher, err := fsnotify.NewWatcher() + require.NoError(t, err) + cws.watcher = watcher + + name := path.Join(os.TempDir(), "test_file") + _, err = os.Create(name) + require.NoError(t, err) + defer os.Remove(name) + + cws.enabled.Store(false) + + cws.addWatcher(ctx, name) + require.False(t, cws.isWatching(name)) + + cws.enabled.Store(true) + + cws.addWatcher(ctx, name) + require.True(t, cws.isWatching(name)) + + cws.addWatcher(ctx, name) + require.True(t, cws.isWatching(name)) + + name = path.Join(os.TempDir(), "noexist_file") + cws.addWatcher(ctx, name) + require.False(t, cws.isWatching(name)) +} + +func TestCredentialWatcherService_watchFiles(t *testing.T) { + var files []string + + ctx := context.Background() + cws := NewCredentialWatcherService(types.AgentConfig()) + watcher, err := fsnotify.NewWatcher() + require.NoError(t, err) + cws.watcher = watcher + + files = append(files, path.Join(os.TempDir(), "test_file1")) + files = append(files, path.Join(os.TempDir(), "test_file2")) + files = append(files, path.Join(os.TempDir(), "test_file3")) + + for _, 
file := range files { + _, err = os.Create(file) + require.NoError(t, err) + } + + cws.watchFiles(ctx, files) + require.True(t, cws.isWatching(path.Join(os.TempDir(), "test_file1"))) + require.True(t, cws.isWatching(path.Join(os.TempDir(), "test_file2"))) + require.True(t, cws.isWatching(path.Join(os.TempDir(), "test_file3"))) + + for _, file := range files { + err = os.Remove(file) + cws.filesBeingWatched.Delete(file) + require.NoError(t, err) + } + + require.False(t, cws.isWatching(path.Join(os.TempDir(), "test_file1"))) + require.False(t, cws.isWatching(path.Join(os.TempDir(), "test_file2"))) + require.False(t, cws.isWatching(path.Join(os.TempDir(), "test_file3"))) +} + +func TestCredentialWatcherService_checkForUpdates(t *testing.T) { + ctx := context.Background() + cws := NewCredentialWatcherService(types.AgentConfig()) + watcher, err := fsnotify.NewWatcher() + require.NoError(t, err) + cws.watcher = watcher + + name := path.Join(os.TempDir(), "test_file") + _, err = os.Create(name) + require.NoError(t, err) + cws.addWatcher(ctx, name) + require.True(t, cws.isWatching(name)) + + cws.filesChanged.Store(true) + ch := make(chan CredentialUpdateMessage) + go cws.checkForUpdates(ctx, ch) + + select { + case <-ctx.Done(): + t.Error(ctx.Err()) + case cu := <-ch: + t.Logf("check for update success %v", cu) + case <-time.After(2 * monitoringInterval): + t.Error("timeout waiting for update") + } +} + +func TestCredentialWatcherService_handleEvent(t *testing.T) { + ctx := context.Background() + cws := NewCredentialWatcherService(types.AgentConfig()) + watcher, err := fsnotify.NewWatcher() + require.NoError(t, err) + cws.watcher = watcher + + cws.handleEvent(ctx, fsnotify.Event{Name: "test-write", Op: fsnotify.Chmod}) + assert.False(t, cws.filesChanged.Load()) + cws.handleEvent(ctx, fsnotify.Event{Name: "test-create", Op: fsnotify.Create}) + assert.False(t, cws.filesChanged.Load()) + cws.handleEvent(ctx, fsnotify.Event{Name: "test-remove", Op: fsnotify.Remove}) + 
assert.True(t, cws.filesChanged.Load()) + cws.handleEvent(ctx, fsnotify.Event{Name: "test-rename", Op: fsnotify.Rename}) + assert.True(t, cws.filesChanged.Load()) + cws.handleEvent(ctx, fsnotify.Event{Name: "test-write", Op: fsnotify.Write}) + assert.True(t, cws.filesChanged.Load()) +} + +func Test_credentialPaths(t *testing.T) { + tests := []struct { + name string + agentConfig *config.Config + want []string + }{ + { + name: "Test 1: Returns expected paths when Auth TokenPath is set", + agentConfig: types.AgentConfig(), + want: []string{ + "/tmp/token", + }, + }, + { + name: "Test 2: Returns empty slice when Auth TokenPath is not set", + agentConfig: &config.Config{ + Command: &config.Command{ + Server: nil, + Auth: nil, + TLS: nil, + }, + }, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, credentialPaths(tt.agentConfig), "credentialPaths(%v)", tt.agentConfig) + }) + } +} diff --git a/internal/watcher/process/process_operator_test.go b/internal/watcher/process/process_operator_test.go new file mode 100644 index 000000000..8150ee1d1 --- /dev/null +++ b/internal/watcher/process/process_operator_test.go @@ -0,0 +1,17 @@ +// Copyright (c) F5, Inc. +// +// This source code is licensed under the Apache License, Version 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+ +package process + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewProcessOperator(t *testing.T) { + po := NewProcessOperator() + assert.NotNil(t, po) +} diff --git a/internal/watcher/watcher_plugin.go b/internal/watcher/watcher_plugin.go index 82be0ed05..3e00d5459 100644 --- a/internal/watcher/watcher_plugin.go +++ b/internal/watcher/watcher_plugin.go @@ -11,6 +11,10 @@ import ( "slices" "sync" + "github.com/nginx/agent/v3/internal/grpc" + + "github.com/nginx/agent/v3/internal/watcher/credentials" + mpi "github.com/nginx/agent/v3/api/grpc/mpi/v1" "github.com/nginx/agent/v3/internal/watcher/file" @@ -34,10 +38,12 @@ type ( instanceWatcherService instanceWatcherServiceInterface healthWatcherService *health.HealthWatcherService fileWatcherService *file.FileWatcherService + credentialWatcherService credentialWatcherServiceInterface instanceUpdatesChannel chan instance.InstanceUpdatesMessage nginxConfigContextChannel chan instance.NginxConfigContextMessage instanceHealthChannel chan health.InstanceHealthMessage fileUpdatesChannel chan file.FileUpdateMessage + credentialUpdatesChannel chan credentials.CredentialUpdateMessage cancel context.CancelFunc instancesWithConfigApplyInProgress []string watcherMutex sync.Mutex @@ -52,6 +58,14 @@ type ( ReparseConfig(ctx context.Context, instanceID string) ReparseConfigs(ctx context.Context) } + + credentialWatcherServiceInterface interface { + Watch( + ctx context.Context, + credentialUpdateChannel chan<- credentials.CredentialUpdateMessage, + ) + SetEnabled(enabled bool) + } ) var _ bus.Plugin = (*Watcher)(nil) @@ -62,10 +76,12 @@ func NewWatcher(agentConfig *config.Config) *Watcher { instanceWatcherService: instance.NewInstanceWatcherService(agentConfig), healthWatcherService: health.NewHealthWatcherService(agentConfig), fileWatcherService: file.NewFileWatcherService(agentConfig), + credentialWatcherService: credentials.NewCredentialWatcherService(agentConfig), instanceUpdatesChannel: 
make(chan instance.InstanceUpdatesMessage), nginxConfigContextChannel: make(chan instance.NginxConfigContextMessage), instanceHealthChannel: make(chan health.InstanceHealthMessage), fileUpdatesChannel: make(chan file.FileUpdateMessage), + credentialUpdatesChannel: make(chan credentials.CredentialUpdateMessage), instancesWithConfigApplyInProgress: []string{}, watcherMutex: sync.Mutex{}, } @@ -82,6 +98,7 @@ func (w *Watcher) Init(ctx context.Context, messagePipe bus.MessagePipeInterface go w.instanceWatcherService.Watch(watcherContext, w.instanceUpdatesChannel, w.nginxConfigContextChannel) go w.healthWatcherService.Watch(watcherContext, w.instanceHealthChannel) + go w.credentialWatcherService.Watch(watcherContext, w.credentialUpdatesChannel) if w.agentConfig.IsFeatureEnabled(pkgConfig.FeatureFileWatcher) { go w.fileWatcherService.Watch(watcherContext, w.fileUpdatesChannel) @@ -110,6 +127,8 @@ func (*Watcher) Info() *bus.Info { func (w *Watcher) Process(ctx context.Context, msg *bus.Message) { switch msg.Topic { + case bus.CredentialUpdatedTopic: + w.handleCredentialUpdate(ctx) case bus.ConfigApplyRequestTopic: w.handleConfigApplyRequest(ctx, msg) case bus.ConfigApplySuccessfulTopic: @@ -125,6 +144,7 @@ func (w *Watcher) Process(ctx context.Context, msg *bus.Message) { func (*Watcher) Subscriptions() []string { return []string{ + bus.CredentialUpdatedTopic, bus.ConfigApplyRequestTopic, bus.ConfigApplySuccessfulTopic, bus.ConfigApplyCompleteTopic, @@ -154,6 +174,7 @@ func (w *Watcher) handleConfigApplyRequest(ctx context.Context, msg *bus.Message w.watcherMutex.Lock() defer w.watcherMutex.Unlock() w.instancesWithConfigApplyInProgress = append(w.instancesWithConfigApplyInProgress, instanceID) + w.fileWatcherService.SetEnabled(false) } @@ -175,6 +196,7 @@ func (w *Watcher) handleConfigApplySuccess(ctx context.Context, msg *bus.Message return element == instanceID }, ) + w.fileWatcherService.SetEnabled(true) w.watcherMutex.Unlock() @@ -206,14 +228,38 @@ func (w *Watcher) 
handleConfigApplyComplete(ctx context.Context, msg *bus.Messag return element == instanceID }, ) + w.fileWatcherService.SetEnabled(true) } +func (w *Watcher) handleCredentialUpdate(ctx context.Context) { + slog.DebugContext(ctx, "Received credential update topic") + + w.watcherMutex.Lock() + conn, err := grpc.NewGrpcConnection(ctx, w.agentConfig) + if err != nil { + slog.ErrorContext(ctx, "Unable to create new grpc connection", "error", err) + w.watcherMutex.Unlock() + + return + } + w.watcherMutex.Unlock() + w.messagePipe.Process(ctx, &bus.Message{ + Topic: bus.ConnectionResetTopic, Data: conn, + }) +} + func (w *Watcher) monitorWatchers(ctx context.Context) { for { select { case <-ctx.Done(): return + case message := <-w.credentialUpdatesChannel: + slog.DebugContext(ctx, "Received credential update event") + newCtx := context.WithValue(ctx, logger.CorrelationIDContextKey, message.CorrelationID) + w.messagePipe.Process(newCtx, &bus.Message{ + Topic: bus.CredentialUpdatedTopic, Data: nil, + }) case message := <-w.instanceUpdatesChannel: newCtx := context.WithValue(ctx, logger.CorrelationIDContextKey, message.CorrelationID) w.handleInstanceUpdates(newCtx, message) @@ -244,7 +290,6 @@ func (w *Watcher) monitorWatchers(ctx context.Context) { w.messagePipe.Process(newCtx, &bus.Message{ Topic: bus.InstanceHealthTopic, Data: message.InstanceHealth, }) - case message := <-w.fileUpdatesChannel: newCtx := context.WithValue(ctx, logger.CorrelationIDContextKey, message.CorrelationID) // Running this in a separate go routine otherwise we get into a deadlock diff --git a/internal/watcher/watcher_plugin_test.go b/internal/watcher/watcher_plugin_test.go index 16a3d4bb1..2a739d530 100644 --- a/internal/watcher/watcher_plugin_test.go +++ b/internal/watcher/watcher_plugin_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/nginx/agent/v3/internal/watcher/credentials" + "github.com/nginx/agent/v3/internal/bus/busfakes" 
"google.golang.org/protobuf/types/known/timestamppb" @@ -71,11 +73,16 @@ func TestWatcher_Init(t *testing.T) { InstanceHealth: []*mpi.InstanceHealth{}, } + credentialUpdateMessage := credentials.CredentialUpdateMessage{ + CorrelationID: logger.GenerateCorrelationID(), + } + watcherPlugin.instanceUpdatesChannel <- instanceUpdatesMessage watcherPlugin.nginxConfigContextChannel <- nginxConfigContextMessage watcherPlugin.instanceHealthChannel <- instanceHealthMessage + watcherPlugin.credentialUpdatesChannel <- credentialUpdateMessage - assert.Eventually(t, func() bool { return len(messagePipe.GetMessages()) == 5 }, 2*time.Second, 10*time.Millisecond) + assert.Eventually(t, func() bool { return len(messagePipe.GetMessages()) == 6 }, 2*time.Second, 10*time.Millisecond) messages = messagePipe.GetMessages() assert.Equal( @@ -103,6 +110,9 @@ func TestWatcher_Init(t *testing.T) { &bus.Message{Topic: bus.InstanceHealthTopic, Data: instanceHealthMessage.InstanceHealth}, messages[4], ) + assert.Equal(t, + &bus.Message{Topic: bus.CredentialUpdatedTopic, Data: nil}, + messages[5]) } func TestWatcher_Info(t *testing.T) { @@ -110,6 +120,24 @@ func TestWatcher_Info(t *testing.T) { assert.Equal(t, &bus.Info{Name: "watcher"}, watcherPlugin.Info()) } +func TestWatcher_Process_CredentialUpdatedTopic(t *testing.T) { + ctx := context.Background() + + watcherPlugin := NewWatcher(types.AgentConfig()) + + messagePipe := busfakes.NewFakeMessagePipe() + + err := watcherPlugin.Init(ctx, messagePipe) + require.NoError(t, err) + + message := &bus.Message{ + Topic: bus.CredentialUpdatedTopic, + Data: nil, + } + + watcherPlugin.Process(ctx, message) +} + func TestWatcher_Process_ConfigApplyRequestTopic(t *testing.T) { ctx := context.Background() data := &mpi.ManagementPlaneRequest{ @@ -201,6 +229,7 @@ func TestWatcher_Subscriptions(t *testing.T) { assert.Equal( t, []string{ + bus.CredentialUpdatedTopic, bus.ConfigApplyRequestTopic, bus.ConfigApplySuccessfulTopic, bus.ConfigApplyCompleteTopic, diff 
--git a/test/types/config.go b/test/types/config.go index c03997c44..729a20825 100644 --- a/test/types/config.go +++ b/test/types/config.go @@ -134,7 +134,7 @@ func AgentConfig() *config.Config { }, Auth: &config.AuthConfig{ Token: "1234", - TokenPath: "", + TokenPath: "/tmp/token", }, TLS: &config.TLSConfig{ Cert: "cert.pem",