Skip to content

Commit cd41a35

Browse files
committed
Ensure in progress requests are processed before resetting the gRPC connection with the management plane
1 parent b1947ec commit cd41a35

File tree

7 files changed

+97
-35
lines changed

7 files changed

+97
-35
lines changed

internal/command/command_plugin.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ func (cp *CommandPlugin) Process(ctx context.Context, msg *bus.Message) {
118118
if logger.ServerType(ctxWithMetadata) == cp.commandServerType.String() {
119119
switch msg.Topic {
120120
case bus.ConnectionResetTopic:
121-
cp.processConnectionReset(ctxWithMetadata, msg)
121+
// Running as a separate go routine so that the command plugin can continue to process data plane responses
122+
// while the connection reset is in progress
123+
go cp.processConnectionReset(ctxWithMetadata, msg)
122124
case bus.ResourceUpdateTopic:
123125
cp.processResourceUpdate(ctxWithMetadata, msg)
124126
case bus.InstanceHealthTopic:
@@ -232,11 +234,19 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me
232234
slog.DebugContext(ctx, "Command plugin received connection reset message")
233235

234236
if newConnection, ok := msg.Data.(grpc.GrpcConnectionInterface); ok {
235-
slog.DebugContext(ctx, "Canceling Subscribe after connection reset")
236237
ctxWithMetadata := cp.config.NewContextWithLabels(ctx)
237238
cp.subscribeMutex.Lock()
238239
defer cp.subscribeMutex.Unlock()
239240

241+
// Update the command service with the new client first
242+
err := cp.commandService.UpdateClient(ctxWithMetadata, newConnection.CommandServiceClient())
243+
if err != nil {
244+
slog.ErrorContext(ctx, "Failed to reset connection", "error", err)
245+
return
246+
}
247+
248+
// Once the command service is updated, we close the old connection
249+
slog.DebugContext(ctx, "Canceling Subscribe after connection reset")
240250
if cp.subscribeCancel != nil {
241251
cp.subscribeCancel()
242252
slog.DebugContext(ctxWithMetadata, "Successfully canceled subscribe after connection reset")
@@ -248,12 +258,6 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me
248258
}
249259

250260
cp.conn = newConnection
251-
err := cp.commandService.UpdateClient(ctx, cp.conn.CommandServiceClient())
252-
if err != nil {
253-
slog.ErrorContext(ctx, "Failed to reset connection", "error", err)
254-
return
255-
}
256-
257261
slog.DebugContext(ctxWithMetadata, "Starting new subscribe after connection reset")
258262
subscribeCtx, cp.subscribeCancel = context.WithCancel(ctxWithMetadata)
259263
go cp.commandService.Subscribe(subscribeCtx)

internal/command/command_plugin_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,14 @@ func TestCommandPlugin_Process(t *testing.T) {
150150
Topic: bus.ConnectionResetTopic,
151151
Data: commandPlugin.conn,
152152
})
153-
require.Equal(t, 1, fakeCommandService.UpdateClientCallCount())
153+
154+
// Separate goroutine is executed so need to wait for it to complete
155+
assert.Eventually(
156+
t,
157+
func() bool { return fakeCommandService.UpdateClientCallCount() == 1 },
158+
2*time.Second,
159+
10*time.Millisecond,
160+
)
154161
}
155162

156163
func TestCommandPlugin_monitorSubscribeChannel(t *testing.T) {

internal/command/command_service.go

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"log/slog"
1313
"sync"
1414
"sync/atomic"
15+
"time"
1516

1617
"google.golang.org/grpc/codes"
1718
"google.golang.org/grpc/status"
@@ -33,6 +34,7 @@ var _ commandService = (*CommandService)(nil)
3334

3435
const (
3536
createConnectionMaxElapsedTime = 0
37+
timeToWaitBetweenChecks = 5 * time.Second
3638
)
3739

3840
type (
@@ -41,8 +43,10 @@ type (
4143
subscribeClient mpi.CommandService_SubscribeClient
4244
agentConfig *config.Config
4345
isConnected *atomic.Bool
46+
connectionResetInProgress *atomic.Bool
4447
subscribeChannel chan *mpi.ManagementPlaneRequest
4548
configApplyRequestQueue map[string][]*mpi.ManagementPlaneRequest // key is the instance ID
49+
requestsInProgress map[string]*mpi.ManagementPlaneRequest // key is the correlation ID
4650
resource *mpi.Resource
4751
subscribeClientMutex sync.Mutex
4852
configApplyRequestQueueMutex sync.Mutex
@@ -55,19 +59,16 @@ func NewCommandService(
5559
agentConfig *config.Config,
5660
subscribeChannel chan *mpi.ManagementPlaneRequest,
5761
) *CommandService {
58-
isConnected := &atomic.Bool{}
59-
isConnected.Store(false)
60-
61-
commandService := &CommandService{
62-
commandServiceClient: commandServiceClient,
63-
agentConfig: agentConfig,
64-
isConnected: isConnected,
65-
subscribeChannel: subscribeChannel,
66-
configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest),
67-
resource: &mpi.Resource{},
62+
return &CommandService{
63+
commandServiceClient: commandServiceClient,
64+
agentConfig: agentConfig,
65+
isConnected: &atomic.Bool{},
66+
connectionResetInProgress: &atomic.Bool{},
67+
subscribeChannel: subscribeChannel,
68+
configApplyRequestQueue: make(map[string][]*mpi.ManagementPlaneRequest),
69+
resource: &mpi.Resource{},
70+
requestsInProgress: make(map[string]*mpi.ManagementPlaneRequest),
6871
}
69-
70-
return commandService
7172
}
7273

7374
func (cs *CommandService) IsConnected() bool {
@@ -176,6 +177,11 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m
176177
return err
177178
}
178179

180+
if response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_OK ||
181+
response.GetCommandResponse().GetStatus() == mpi.CommandResponse_COMMAND_STATUS_FAILURE {
182+
delete(cs.requestsInProgress, response.GetMessageMeta().GetCorrelationId())
183+
}
184+
179185
return backoff.Retry(
180186
cs.sendDataPlaneResponseCallback(ctx, response),
181187
backoffHelpers.Context(backOffCtx, cs.agentConfig.Client.Backoff),
@@ -256,6 +262,33 @@ func (cs *CommandService) CreateConnection(
256262
}
257263

258264
func (cs *CommandService) UpdateClient(ctx context.Context, client mpi.CommandServiceClient) error {
265+
cs.connectionResetInProgress.Store(true)
266+
defer cs.connectionResetInProgress.Store(false)
267+
268+
// Wait for any in-progress requests to complete before updating the client
269+
start := time.Now()
270+
271+
for len(cs.requestsInProgress) > 0 {
272+
if time.Since(start) >= cs.agentConfig.Client.Grpc.ConnectionResetTimeout {
273+
slog.WarnContext(
274+
ctx,
275+
"Timeout reached while waiting for in-progress requests to complete",
276+
"number_of_requests_in_progress", len(cs.requestsInProgress),
277+
)
278+
279+
break
280+
}
281+
282+
slog.InfoContext(
283+
ctx,
284+
"Waiting for in-progress requests to complete before updating command service gRPC client",
285+
"max_wait_time", cs.agentConfig.Client.Grpc.ConnectionResetTimeout,
286+
"number_of_requests_in_progress", len(cs.requestsInProgress),
287+
)
288+
289+
time.Sleep(timeToWaitBetweenChecks)
290+
}
291+
259292
cs.subscribeClientMutex.Lock()
260293
cs.commandServiceClient = client
261294
cs.subscribeClientMutex.Unlock()
@@ -363,7 +396,7 @@ func (cs *CommandService) sendResponseForQueuedConfigApplyRequests(
363396
cs.configApplyRequestQueue[instanceID] = cs.configApplyRequestQueue[instanceID][indexOfConfigApplyRequest+1:]
364397
slog.DebugContext(ctx, "Removed config apply requests from queue", "queue", cs.configApplyRequestQueue[instanceID])
365398

366-
if len(cs.configApplyRequestQueue[instanceID]) > 0 {
399+
if len(cs.configApplyRequestQueue[instanceID]) > 0 && !cs.connectionResetInProgress.Load() {
367400
cs.subscribeChannel <- cs.configApplyRequestQueue[instanceID][len(cs.configApplyRequestQueue[instanceID])-1]
368401
}
369402

@@ -404,6 +437,12 @@ func (cs *CommandService) dataPlaneHealthCallback(
404437
//nolint:revive // cognitive complexity is 18
405438
func (cs *CommandService) receiveCallback(ctx context.Context) func() error {
406439
return func() error {
440+
if cs.connectionResetInProgress.Load() {
441+
slog.DebugContext(ctx, "Connection reset in progress, skipping receive from subscribe stream")
442+
443+
return nil
444+
}
445+
407446
cs.subscribeClientMutex.Lock()
408447

409448
if cs.subscribeClient == nil {
@@ -444,6 +483,8 @@ func (cs *CommandService) receiveCallback(ctx context.Context) func() error {
444483
default:
445484
cs.subscribeChannel <- request
446485
}
486+
487+
cs.requestsInProgress[request.GetMessageMeta().GetCorrelationId()] = request
447488
}
448489

449490
return nil
@@ -476,7 +517,7 @@ func (cs *CommandService) queueConfigApplyRequests(ctx context.Context, request
476517

477518
instanceID := request.GetConfigApplyRequest().GetOverview().GetConfigVersion().GetInstanceId()
478519
cs.configApplyRequestQueue[instanceID] = append(cs.configApplyRequestQueue[instanceID], request)
479-
if len(cs.configApplyRequestQueue[instanceID]) == 1 {
520+
if len(cs.configApplyRequestQueue[instanceID]) == 1 && !cs.connectionResetInProgress.Load() {
480521
cs.subscribeChannel <- request
481522
} else {
482523
slog.DebugContext(

internal/config/config.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,12 @@ func registerClientFlags(fs *flag.FlagSet) {
623623
"File chunk size in bytes.",
624624
)
625625

626+
fs.Duration(
627+
ClientGRPCConnectionResetTimeoutKey,
628+
DefGRPCConnectionResetTimeout,
629+
"Duration to wait for in-progress management plane requests to complete before resetting the gRPC connection.",
630+
)
631+
626632
fs.Uint32(
627633
ClientGRPCMaxFileSizeKey,
628634
DefMaxFileSize,
@@ -1112,6 +1118,7 @@ func resolveClient() *Client {
11121118
MaxFileSize: viperInstance.GetUint32(ClientGRPCMaxFileSizeKey),
11131119
FileChunkSize: viperInstance.GetUint32(ClientGRPCFileChunkSizeKey),
11141120
MaxParallelFileOperations: viperInstance.GetInt(ClientGRPCMaxParallelFileOperationsKey),
1121+
ConnectionResetTimeout: viperInstance.GetDuration(ClientGRPCConnectionResetTimeoutKey),
11151122
},
11161123
Backoff: &BackOff{
11171124
InitialInterval: viperInstance.GetDuration(ClientBackoffInitialIntervalKey),

internal/config/defaults.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ const (
6060
DefAuxiliaryCommandTLServerNameKey = ""
6161

6262
// Client GRPC Settings
63-
DefMaxMessageSize = 0 // 0 = unset
64-
DefMaxMessageRecieveSize = 4194304 // default 4 MB
65-
DefMaxMessageSendSize = 4194304 // default 4 MB
66-
DefMaxFileSize uint32 = 1048576 // 1MB
67-
DefFileChunkSize uint32 = 524288 // 0.5MB
68-
DefMaxParallelFileOperations = 5
63+
DefMaxMessageSize = 0 // 0 = unset
64+
DefMaxMessageRecieveSize = 4194304 // default 4 MB
65+
DefMaxMessageSendSize = 4194304 // default 4 MB
66+
DefMaxFileSize uint32 = 1048576 // 1MB
67+
DefFileChunkSize uint32 = 524288 // 0.5MB
68+
DefMaxParallelFileOperations = 5
69+
DefGRPCConnectionResetTimeout = 3 * time.Minute
6970

7071
// Client HTTP Settings
7172
DefHTTPTimeout = 10 * time.Second

internal/config/flags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ var (
4141
ClientGRPCMaxFileSizeKey = pre(ClientRootKey) + "grpc_max_file_size"
4242
ClientGRPCFileChunkSizeKey = pre(ClientRootKey) + "grpc_file_chunk_size"
4343
ClientGRPCMaxParallelFileOperationsKey = pre(ClientRootKey) + "grpc_max_parallel_file_operations"
44+
ClientGRPCConnectionResetTimeoutKey = pre(ClientRootKey) + "grpc_connection_reset_timeout"
4445

4546
ClientBackoffInitialIntervalKey = pre(ClientRootKey) + "backoff_initial_interval"
4647
ClientBackoffMaxIntervalKey = pre(ClientRootKey) + "backoff_max_interval"

internal/config/types.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,13 @@ type (
9696
KeepAlive *KeepAlive `yaml:"keepalive" mapstructure:"keepalive"`
9797
// if MaxMessageSize is size set then we use that value,
9898
// otherwise MaxMessageRecieveSize and MaxMessageSendSize for individual settings
99-
MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"`
100-
MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"`
101-
MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"`
102-
MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"`
103-
FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"`
104-
MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"`
99+
MaxMessageSize int `yaml:"max_message_size" mapstructure:"max_message_size"`
100+
MaxMessageReceiveSize int `yaml:"max_message_receive_size" mapstructure:"max_message_receive_size"`
101+
MaxMessageSendSize int `yaml:"max_message_send_size" mapstructure:"max_message_send_size"`
102+
MaxFileSize uint32 `yaml:"max_file_size" mapstructure:"max_file_size"`
103+
FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"`
104+
MaxParallelFileOperations int `yaml:"max_parallel_file_operations" mapstructure:"max_parallel_file_operations"`
105+
ConnectionResetTimeout time.Duration `yaml:"connection_reset_timeout" mapstructure:"connection_reset_timeout"`
105106
}
106107

107108
KeepAlive struct {

0 commit comments

Comments
 (0)