diff --git a/internal/command/command_plugin.go b/internal/command/command_plugin.go index 2ce0c6842..feb1c3b6d 100644 --- a/internal/command/command_plugin.go +++ b/internal/command/command_plugin.go @@ -120,7 +120,9 @@ func (cp *CommandPlugin) Process(ctx context.Context, msg *bus.Message) { if logger.ServerType(ctxWithMetadata) == cp.commandServerType.String() { switch msg.Topic { case bus.ConnectionResetTopic: - cp.processConnectionReset(ctxWithMetadata, msg) + // Running as a separate goroutine so that the command plugin can continue to process data plane responses + // while the connection reset is in progress + go cp.processConnectionReset(ctxWithMetadata, msg) case bus.ResourceUpdateTopic: cp.processResourceUpdate(ctxWithMetadata, msg) case bus.InstanceHealthTopic: @@ -254,11 +256,19 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me slog.DebugContext(ctx, "Command plugin received connection reset message") if newConnection, ok := msg.Data.(grpc.GrpcConnectionInterface); ok { - slog.DebugContext(ctx, "Canceling Subscribe after connection reset") ctxWithMetadata := cp.config.NewContextWithLabels(ctx) cp.subscribeMutex.Lock() defer cp.subscribeMutex.Unlock() + // Update the command service with the new client first + err := cp.commandService.UpdateClient(ctxWithMetadata, newConnection.CommandServiceClient()) + if err != nil { + slog.ErrorContext(ctx, "Failed to reset connection", "error", err) + return + } + + // Once the command service is updated, we close the old connection + slog.DebugContext(ctx, "Canceling Subscribe after connection reset") if cp.subscribeCancel != nil { cp.subscribeCancel() slog.DebugContext(ctxWithMetadata, "Successfully canceled subscribe after connection reset") @@ -270,12 +280,6 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me } cp.conn = newConnection - err := cp.commandService.UpdateClient(ctx, cp.conn.CommandServiceClient()) - if err != nil { - 
slog.ErrorContext(ctx, "Failed to reset connection", "error", err) - return - } - slog.DebugContext(ctxWithMetadata, "Starting new subscribe after connection reset") subscribeCtx, cp.subscribeCancel = context.WithCancel(ctxWithMetadata) go cp.commandService.Subscribe(subscribeCtx) diff --git a/internal/command/command_plugin_test.go b/internal/command/command_plugin_test.go index 87af33a9a..ef8782614 100644 --- a/internal/command/command_plugin_test.go +++ b/internal/command/command_plugin_test.go @@ -164,7 +164,14 @@ func TestCommandPlugin_Process(t *testing.T) { Topic: bus.ConnectionResetTopic, Data: commandPlugin.conn, }) - require.Equal(t, 1, fakeCommandService.UpdateClientCallCount()) + + // Separate goroutine is executed so need to wait for it to complete + assert.Eventually( + t, + func() bool { return fakeCommandService.UpdateClientCallCount() == 1 }, + 2*time.Second, + 10*time.Millisecond, + ) } func TestCommandPlugin_monitorSubscribeChannel(t *testing.T) { diff --git a/internal/command/command_service.go b/internal/command/command_service.go index 97cd8a9ba..61daf603f 100644 --- a/internal/command/command_service.go +++ b/internal/command/command_service.go @@ -12,6 +12,7 @@ import ( "log/slog" "sync" "sync/atomic" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -31,9 +32,9 @@ import ( var _ commandService = (*CommandService)(nil) -const ( - createConnectionMaxElapsedTime = 0 -) +const createConnectionMaxElapsedTime = 0 + +var timeToWaitBetweenChecks = 5 * time.Second type ( CommandService struct { @@ -41,8 +42,10 @@ type ( subscribeClient mpi.CommandService_SubscribeClient agentConfig *config.Config isConnected *atomic.Bool + connectionResetInProgress *atomic.Bool subscribeChannel chan *mpi.ManagementPlaneRequest configApplyRequestQueue map[string][]*mpi.ManagementPlaneRequest // key is the instance ID + requestsInProgress map[string]*mpi.ManagementPlaneRequest // key is the correlation ID resource *mpi.Resource 
subscribeClientMutex sync.Mutex configApplyRequestQueueMutex sync.Mutex @@ -56,19 +59,16 @@ func NewCommandService( agentConfig *config.Config, subscribeChannel chan *mpi.ManagementPlaneRequest, ) *CommandService { - isConnected := &atomic.Bool{} - isConnected.Store(false) - - commandService := &CommandService{ - commandServiceClient: commandServiceClient, - agentConfig: agentConfig, - isConnected: isConnected, - subscribeChannel: subscribeChannel, - configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest), - resource: &mpi.Resource{}, + return &CommandService{ + commandServiceClient: commandServiceClient, + agentConfig: agentConfig, + isConnected: &atomic.Bool{}, + connectionResetInProgress: &atomic.Bool{}, + subscribeChannel: subscribeChannel, + configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest), + resource: &mpi.Resource{}, + requestsInProgress: make(map[string]*mpi.ManagementPlaneRequest), } - - return commandService } func (cs *CommandService) IsConnected() bool { @@ -181,6 +181,11 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m return err } + if response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_OK || + response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_FAILURE { + delete(cs.requestsInProgress, response.GetMessageMeta().GetCorrelationId()) + } + return backoff.Retry( cs.sendDataPlaneResponseCallback(ctx, response), backoffHelpers.Context(backOffCtx, cs.agentConfig.Client.Backoff), @@ -272,6 +277,33 @@ func (cs *CommandService) CreateConnection( } func (cs *CommandService) UpdateClient(ctx context.Context, client mpi.CommandServiceClient) error { + cs.connectionResetInProgress.Store(true) + defer cs.connectionResetInProgress.Store(false) + + // Wait for any in-progress requests to complete before updating the client + start := time.Now() + + for len(cs.requestsInProgress) > 0 { + if time.Since(start) >= 
cs.agentConfig.Client.Grpc.ConnectionResetTimeout { + slog.WarnContext( + ctx, + "Timeout reached while waiting for in-progress requests to complete", + "number_of_requests_in_progress", len(cs.requestsInProgress), + ) + + break + } + + slog.InfoContext( + ctx, + "Waiting for in-progress requests to complete before updating command service gRPC client", + "max_wait_time", cs.agentConfig.Client.Grpc.ConnectionResetTimeout, + "number_of_requests_in_progress", len(cs.requestsInProgress), + ) + + time.Sleep(timeToWaitBetweenChecks) + } + cs.subscribeClientMutex.Lock() cs.commandServiceClient = client cs.subscribeClientMutex.Unlock() @@ -379,7 +411,7 @@ func (cs *CommandService) sendResponseForQueuedConfigApplyRequests( cs.configApplyRequestQueue[instanceID] = cs.configApplyRequestQueue[instanceID][indexOfConfigApplyRequest+1:] slog.DebugContext(ctx, "Removed config apply requests from queue", "queue", cs.configApplyRequestQueue[instanceID]) - if len(cs.configApplyRequestQueue[instanceID]) > 0 { + if len(cs.configApplyRequestQueue[instanceID]) > 0 && !cs.connectionResetInProgress.Load() { cs.subscribeChannel <- cs.configApplyRequestQueue[instanceID][len(cs.configApplyRequestQueue[instanceID])-1] } @@ -423,6 +455,12 @@ func (cs *CommandService) dataPlaneHealthCallback( //nolint:revive // cognitive complexity is 18 func (cs *CommandService) receiveCallback(ctx context.Context) func() error { return func() error { + if cs.connectionResetInProgress.Load() { + slog.DebugContext(ctx, "Connection reset in progress, skipping receive from subscribe stream") + + return nil + } + cs.subscribeClientMutex.Lock() if cs.subscribeClient == nil { @@ -463,6 +501,8 @@ func (cs *CommandService) receiveCallback(ctx context.Context) func() error { default: cs.subscribeChannel <- request } + + cs.requestsInProgress[request.GetMessageMeta().GetCorrelationId()] = request } return nil @@ -495,7 +535,7 @@ func (cs *CommandService) queueConfigApplyRequests(ctx context.Context, request instanceID 
:= request.GetConfigApplyRequest().GetOverview().GetConfigVersion().GetInstanceId() cs.configApplyRequestQueue[instanceID] = append(cs.configApplyRequestQueue[instanceID], request) - if len(cs.configApplyRequestQueue[instanceID]) == 1 { + if len(cs.configApplyRequestQueue[instanceID]) == 1 && !cs.connectionResetInProgress.Load() { cs.subscribeChannel <- request } else { slog.DebugContext( diff --git a/internal/command/command_service_test.go b/internal/command/command_service_test.go index d91e9fe0f..f37f335d7 100644 --- a/internal/command/command_service_test.go +++ b/internal/command/command_service_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/nginx/agent/v3/internal/logger" + "github.com/nginx/agent/v3/pkg/id" "github.com/nginx/agent/v3/test/helpers" "github.com/nginx/agent/v3/test/stub" @@ -211,6 +212,37 @@ func TestCommandService_UpdateClient(t *testing.T) { assert.NotNil(t, commandService.commandServiceClient) } +func TestCommandService_UpdateClient_requestInProgress(t *testing.T) { + commandServiceClient := &v1fakes.FakeCommandServiceClient{} + ctx := context.Background() + + commandService := NewCommandService( + commandServiceClient, + types.AgentConfig(), + make(chan *mpi.ManagementPlaneRequest), + ) + + instanceID := id.GenerateMessageID() + + commandService.requestsInProgress[instanceID] = &mpi.ManagementPlaneRequest{} + timeToWaitBetweenChecks = 100 * time.Millisecond + + wg := sync.WaitGroup{} + wg.Add(1) + + var updateClientErr error + + go func() { + updateClientErr = commandService.UpdateClient(ctx, commandServiceClient) + wg.Done() + }() + + wg.Wait() + + require.NoError(t, updateClientErr) + assert.NotNil(t, commandService.commandServiceClient) +} + func TestCommandService_UpdateDataPlaneHealth(t *testing.T) { ctx := context.Background() commandServiceClient := &v1fakes.FakeCommandServiceClient{} diff --git a/internal/config/config.go b/internal/config/config.go index c321dda83..265e4c05f 
100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -623,6 +623,12 @@ func registerClientFlags(fs *flag.FlagSet) { "File chunk size in bytes.", ) + fs.Duration( + ClientGRPCConnectionResetTimeoutKey, + DefGRPCConnectionResetTimeout, + "Duration to wait for in-progress management plane requests to complete before resetting the gRPC connection.", + ) + fs.Uint32( ClientGRPCMaxFileSizeKey, DefMaxFileSize, @@ -1119,6 +1125,7 @@ func resolveClient() *Client { FileChunkSize: viperInstance.GetUint32(ClientGRPCFileChunkSizeKey), ResponseTimeout: viperInstance.GetDuration(ClientGRPCResponseTimeoutKey), MaxParallelFileOperations: viperInstance.GetInt(ClientGRPCMaxParallelFileOperationsKey), + ConnectionResetTimeout: viperInstance.GetDuration(ClientGRPCConnectionResetTimeoutKey), }, Backoff: &BackOff{ InitialInterval: viperInstance.GetDuration(ClientBackoffInitialIntervalKey), diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 8a34e6dd8..2b29212fe 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -60,13 +60,14 @@ const ( DefAuxiliaryCommandTLServerNameKey = "" // Client GRPC Settings - DefMaxMessageSize = 0 // 0 = unset - DefMaxMessageRecieveSize = 4194304 // default 4 MB - DefMaxMessageSendSize = 4194304 // default 4 MB - DefMaxFileSize uint32 = 1048576 // 1MB - DefFileChunkSize uint32 = 524288 // 0.5MB - DefMaxParallelFileOperations = 5 - DefResponseTimeout = 10 * time.Second + DefMaxMessageSize = 0 // 0 = unset + DefMaxMessageRecieveSize = 4194304 // default 4 MB + DefMaxMessageSendSize = 4194304 // default 4 MB + DefMaxFileSize uint32 = 1048576 // 1MB + DefFileChunkSize uint32 = 524288 // 0.5MB + DefMaxParallelFileOperations = 5 + DefResponseTimeout = 10 * time.Second + DefGRPCConnectionResetTimeout = 3 * time.Minute // Client HTTP Settings DefHTTPTimeout = 10 * time.Second diff --git a/internal/config/flags.go b/internal/config/flags.go index a6ac9aac9..8295c1af7 100644 --- 
a/internal/config/flags.go +++ b/internal/config/flags.go @@ -41,6 +41,7 @@ var ( ClientGRPCMaxFileSizeKey = pre(ClientRootKey) + "grpc_max_file_size" ClientGRPCFileChunkSizeKey = pre(ClientRootKey) + "grpc_file_chunk_size" ClientGRPCMaxParallelFileOperationsKey = pre(ClientRootKey) + "grpc_max_parallel_file_operations" + ClientGRPCConnectionResetTimeoutKey = pre(ClientRootKey) + "grpc_connection_reset_timeout" ClientGRPCResponseTimeoutKey = pre(ClientRootKey) + "grpc_response_timeout" ClientBackoffInitialIntervalKey = pre(ClientRootKey) + "backoff_initial_interval" diff --git a/internal/config/types.go b/internal/config/types.go index 385be2a58..c62262fac 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -97,12 +97,13 @@ type ( ResponseTimeout time.Duration `yaml:"response_timeout" mapstructure:"response_timeout"` // if MaxMessageSize is size set then we use that value, // otherwise MaxMessageRecieveSize and MaxMessageSendSize for individual settings - MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"` - MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"` - MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"` - MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"` - FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"` - MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"` + MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"` + MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"` + MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"` + MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"` + FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"` + MaxParallelFileOperations int 
`yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"` + ConnectionResetTimeout time.Duration `yaml:"connection_reset_timeout" mapstructure:"connection_reset_timeout"` } KeepAlive struct { diff --git a/internal/resource/nginx_instance_process_operator.go b/internal/resource/nginx_instance_process_operator.go index 4a13d5146..c8b2bb6b0 100644 --- a/internal/resource/nginx_instance_process_operator.go +++ b/internal/resource/nginx_instance_process_operator.go @@ -74,7 +74,7 @@ func (p *NginxInstanceProcessOperator) FindParentProcessID(ctx context.Context, } processInstanceID := id.Generate("%s_%s_%s", info.ExePath, info.ConfPath, info.Prefix) if instanceID == processInstanceID { - slog.DebugContext(ctx, "Found NGINX process ID", "process_id", processInstanceID) + slog.DebugContext(ctx, "Found NGINX process ID", "instance_id", processInstanceID) return proc.PID, nil } } diff --git a/internal/watcher/instance/nginx_process_parser.go b/internal/watcher/instance/nginx_process_parser.go index 4e9d02b6e..3da8786b8 100644 --- a/internal/watcher/instance/nginx_process_parser.go +++ b/internal/watcher/instance/nginx_process_parser.go @@ -46,14 +46,26 @@ func NewNginxProcessParser() *NginxProcessParser { // cognitive complexity of 16 because of the if statements in the for loop // don't think can be avoided due to the need for continue // -//nolint:revive // cognitive complexity of 20 because of the if statements in the for loop +//nolint:revive,gocognit // cognitive complexity of 20 because of the if statements in the for loop func (npp *NginxProcessParser) Parse(ctx context.Context, processes []*nginxprocess.Process) map[string]*mpi.Instance { + slog.DebugContext(ctx, "Parsing NGINX processes", "number_of_processes", len(processes)) + instanceMap := make(map[string]*mpi.Instance) // key is instanceID workers := make(map[int32][]*mpi.InstanceChild) // key is ppid of process processesByPID := convertToMap(processes) for _, proc := range 
processesByPID { + slog.DebugContext(ctx, "NGINX process details", + "ppid", proc.PPID, + "pid", proc.PID, + "name", proc.Name, + "created", proc.Created, + "status", proc.Status, + "cmd", proc.Cmd, + "exe", proc.Exe, + ) + if proc.IsWorker() { // Here we are determining if the worker process has a master if masterProcess, ok := processesByPID[proc.PPID]; ok { @@ -90,6 +102,15 @@ func (npp *NginxProcessParser) Parse(ctx context.Context, processes []*nginxproc // check if proc is a master process, process is not a worker but could be cache manager etc if proc.IsMaster() { + // sometimes a master process can have another master as parent + // which means that it is actually a worker process and not a master process + if masterProcess, ok := processesByPID[proc.PPID]; ok { + workers[masterProcess.PID] = append(workers[masterProcess.PID], + &mpi.InstanceChild{ProcessId: proc.PID}) + + continue + } + nginxInfo, err := npp.info(ctx, proc) if err != nil { slog.DebugContext(ctx, "Unable to get NGINX info", "pid", proc.PID, "error", err) diff --git a/internal/watcher/instance/nginx_process_parser_test.go b/internal/watcher/instance/nginx_process_parser_test.go index 409e7e976..9b53e85fd 100644 --- a/internal/watcher/instance/nginx_process_parser_test.go +++ b/internal/watcher/instance/nginx_process_parser_test.go @@ -228,6 +228,17 @@ func TestNginxProcessParser_Parse_Processes(t *testing.T) { instancesList[1].GetInstanceMeta().GetInstanceId(): instancesList[1], } + process6 := protos.NginxOssInstance(nil) + process6.GetInstanceRuntime().InstanceChildren = []*mpi.InstanceChild{ + {ProcessId: 567}, + {ProcessId: 789}, + {ProcessId: 5678}, + } + + instancesTest6 := map[string]*mpi.Instance{ + process6.GetInstanceMeta().GetInstanceId(): process6, + } + tests := []struct { expected map[string]*mpi.Instance name string @@ -368,6 +379,40 @@ func TestNginxProcessParser_Parse_Processes(t *testing.T) { }, expected: make(map[string]*mpi.Instance), }, + { + name: "Test 6: 1 master 
process each with 2 workers and 1 master process", + processes: []*nginxprocess.Process{ + { + PID: 1234, + PPID: 1, + Name: "nginx", + Cmd: "nginx: master process /usr/local/opt/nginx/bin/nginx -g daemon off;", + Exe: exePath, + }, + { + PID: 789, + PPID: 1234, + Name: "nginx", + Cmd: "nginx: worker process", + Exe: exePath, + }, + { + PID: 567, + PPID: 1234, + Name: "nginx", + Cmd: "nginx: worker process", + Exe: exePath, + }, + { + PID: 5678, + PPID: 1234, + Name: "nginx", + Cmd: "nginx: master process /usr/local/opt/nginx/bin/nginx -g daemon off;", + Exe: exePath, + }, + }, + expected: instancesTest6, + }, } for _, test := range tests { diff --git a/test/types/config.go b/test/types/config.go index 97f2165df..48e9f4b21 100644 --- a/test/types/config.go +++ b/test/types/config.go @@ -31,6 +31,7 @@ const ( maxParallelFileOperations = 5 reloadMonitoringPeriod = 400 * time.Millisecond + connectionResetTimeout = 200 * time.Millisecond ) // Produces a populated Agent Config for testing usage. @@ -58,6 +59,7 @@ func AgentConfig() *config.Config { MaxFileSize: 1, FileChunkSize: 1, MaxParallelFileOperations: maxParallelFileOperations, + ConnectionResetTimeout: connectionResetTimeout, }, Backoff: &config.BackOff{ InitialInterval: commonInitialInterval,