@@ -12,6 +12,7 @@ import (
1212 "log/slog"
1313 "sync"
1414 "sync/atomic"
15+ "time"
1516
1617 "google.golang.org/grpc/codes"
1718 "google.golang.org/grpc/status"
@@ -33,6 +34,7 @@ var _ commandService = (*CommandService)(nil)
3334
3435const (
3536 createConnectionMaxElapsedTime = 0
37+ timeToWaitBetweenChecks = 5 * time .Second
3638)
3739
3840type (
@@ -41,8 +43,10 @@ type (
4143 subscribeClient mpi.CommandService_SubscribeClient
4244 agentConfig * config.Config
4345 isConnected * atomic.Bool
46+ connectionResetInProgress * atomic.Bool
4447 subscribeChannel chan * mpi.ManagementPlaneRequest
4548 configApplyRequestQueue map [string ][]* mpi.ManagementPlaneRequest // key is the instance ID
49+ requestsInProgress map [string ]* mpi.ManagementPlaneRequest // key is the correlation ID
4650 resource * mpi.Resource
4751 subscribeClientMutex sync.Mutex
4852 configApplyRequestQueueMutex sync.Mutex
@@ -55,19 +59,16 @@ func NewCommandService(
5559 agentConfig * config.Config ,
5660 subscribeChannel chan * mpi.ManagementPlaneRequest ,
5761) * CommandService {
58- isConnected := & atomic.Bool {}
59- isConnected .Store (false )
60-
61- commandService := & CommandService {
62- commandServiceClient : commandServiceClient ,
63- agentConfig : agentConfig ,
64- isConnected : isConnected ,
65- subscribeChannel : subscribeChannel ,
66- configApplyRequestQueue : make (map [string ][]* mpi.ManagementPlaneRequest ),
67- resource : & mpi.Resource {},
62+ return & CommandService {
63+ commandServiceClient : commandServiceClient ,
64+ agentConfig : agentConfig ,
65+ isConnected : & atomic.Bool {},
66+ connectionResetInProgress : & atomic.Bool {},
67+ subscribeChannel : subscribeChannel ,
68+ configApplyRequestQueue : make (map [string ][]* mpi.ManagementPlaneRequest ),
69+ resource : & mpi.Resource {},
70+ requestsInProgress : make (map [string ]* mpi.ManagementPlaneRequest ),
6871 }
69-
70- return commandService
7172}
7273
7374func (cs * CommandService ) IsConnected () bool {
@@ -176,6 +177,11 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m
176177 return err
177178 }
178179
180+ if response .GetCommandResponse ().GetStatus () == mpi .CommandResponse_COMMAND_STATUS_OK ||
181+ response .GetCommandResponse ().GetStatus () == mpi .CommandResponse_COMMAND_STATUS_FAILURE {
182+ delete (cs .requestsInProgress , response .GetMessageMeta ().GetCorrelationId ())
183+ }
184+
179185 return backoff .Retry (
180186 cs .sendDataPlaneResponseCallback (ctx , response ),
181187 backoffHelpers .Context (backOffCtx , cs .agentConfig .Client .Backoff ),
@@ -256,6 +262,33 @@ func (cs *CommandService) CreateConnection(
256262}
257263
258264func (cs * CommandService ) UpdateClient (ctx context.Context , client mpi.CommandServiceClient ) error {
265+ cs .connectionResetInProgress .Store (true )
266+ defer cs .connectionResetInProgress .Store (false )
267+
268+ // Wait for any in-progress requests to complete before updating the client
269+ start := time .Now ()
270+
271+ for len (cs .requestsInProgress ) > 0 {
272+ if time .Since (start ) >= cs .agentConfig .Client .Grpc .ConnectionResetTimeout {
273+ slog .WarnContext (
274+ ctx ,
275+ "Timeout reached while waiting for in-progress requests to complete" ,
276+ "number_of_requests_in_progress" , len (cs .requestsInProgress ),
277+ )
278+
279+ break
280+ }
281+
282+ slog .InfoContext (
283+ ctx ,
284+ "Waiting for in-progress requests to complete before updating command service gRPC client" ,
285+ "max_wait_time" , cs .agentConfig .Client .Grpc .ConnectionResetTimeout ,
286+ "number_of_requests_in_progress" , len (cs .requestsInProgress ),
287+ )
288+
289+ time .Sleep (timeToWaitBetweenChecks )
290+ }
291+
259292 cs .subscribeClientMutex .Lock ()
260293 cs .commandServiceClient = client
261294 cs .subscribeClientMutex .Unlock ()
@@ -363,7 +396,7 @@ func (cs *CommandService) sendResponseForQueuedConfigApplyRequests(
363396 cs .configApplyRequestQueue [instanceID ] = cs .configApplyRequestQueue [instanceID ][indexOfConfigApplyRequest + 1 :]
364397 slog .DebugContext (ctx , "Removed config apply requests from queue" , "queue" , cs .configApplyRequestQueue [instanceID ])
365398
366- if len (cs .configApplyRequestQueue [instanceID ]) > 0 {
399+ if len (cs .configApplyRequestQueue [instanceID ]) > 0 && ! cs . connectionResetInProgress . Load () {
367400 cs .subscribeChannel <- cs .configApplyRequestQueue [instanceID ][len (cs .configApplyRequestQueue [instanceID ])- 1 ]
368401 }
369402
@@ -404,6 +437,12 @@ func (cs *CommandService) dataPlaneHealthCallback(
404437//nolint:revive // cognitive complexity is 18
405438func (cs * CommandService ) receiveCallback (ctx context.Context ) func () error {
406439 return func () error {
440+ if cs .connectionResetInProgress .Load () {
441+ slog .DebugContext (ctx , "Connection reset in progress, skipping receive from subscribe stream" )
442+
443+ return nil
444+ }
445+
407446 cs .subscribeClientMutex .Lock ()
408447
409448 if cs .subscribeClient == nil {
@@ -444,6 +483,8 @@ func (cs *CommandService) receiveCallback(ctx context.Context) func() error {
444483 default :
445484 cs .subscribeChannel <- request
446485 }
486+
487+ cs .requestsInProgress [request .GetMessageMeta ().GetCorrelationId ()] = request
447488 }
448489
449490 return nil
@@ -476,7 +517,7 @@ func (cs *CommandService) queueConfigApplyRequests(ctx context.Context, request
476517
477518 instanceID := request .GetConfigApplyRequest ().GetOverview ().GetConfigVersion ().GetInstanceId ()
478519 cs .configApplyRequestQueue [instanceID ] = append (cs .configApplyRequestQueue [instanceID ], request )
479- if len (cs .configApplyRequestQueue [instanceID ]) == 1 {
520+ if len (cs .configApplyRequestQueue [instanceID ]) == 1 && ! cs . connectionResetInProgress . Load () {
480521 cs .subscribeChannel <- request
481522 } else {
482523 slog .DebugContext (
0 commit comments