diff --git a/internal/command/command_plugin.go b/internal/command/command_plugin.go index 6a950bda0..293adfe14 100644 --- a/internal/command/command_plugin.go +++ b/internal/command/command_plugin.go @@ -8,6 +8,7 @@ package command import ( "context" "log/slog" + "sync" "google.golang.org/protobuf/types/known/timestamppb" @@ -30,7 +31,7 @@ type ( UpdateDataPlaneStatus(ctx context.Context, resource *mpi.Resource) error UpdateDataPlaneHealth(ctx context.Context, instanceHealths []*mpi.InstanceHealth) error SendDataPlaneResponse(ctx context.Context, response *mpi.DataPlaneResponse) error - CancelSubscription(ctx context.Context) + Subscribe(ctx context.Context) IsConnected() bool CreateConnection(ctx context.Context, resource *mpi.Resource) (*mpi.CreateConnectionResponse, error) } @@ -38,9 +39,11 @@ type ( CommandPlugin struct { messagePipe bus.MessagePipeInterface config *config.Config + subscribeCancel context.CancelFunc conn grpc.GrpcConnectionInterface commandService commandService subscribeChannel chan *mpi.ManagementPlaneRequest + subscribeMutex sync.Mutex } ) @@ -56,7 +59,7 @@ func (cp *CommandPlugin) Init(ctx context.Context, messagePipe bus.MessagePipeIn slog.DebugContext(ctx, "Starting command plugin") cp.messagePipe = messagePipe - cp.commandService = NewCommandService(ctx, cp.conn.CommandServiceClient(), cp.config, cp.subscribeChannel) + cp.commandService = NewCommandService(cp.conn.CommandServiceClient(), cp.config, cp.subscribeChannel) go cp.monitorSubscribeChannel(ctx) @@ -64,7 +67,14 @@ func (cp *CommandPlugin) Init(ctx context.Context, messagePipe bus.MessagePipeIn } func (cp *CommandPlugin) Close(ctx context.Context) error { - cp.commandService.CancelSubscription(ctx) + slog.InfoContext(ctx, "Canceling subscribe context") + + cp.subscribeMutex.Lock() + if cp.subscribeCancel != nil { + cp.subscribeCancel() + } + cp.subscribeMutex.Unlock() + return cp.conn.Close(ctx) } @@ -103,11 +113,20 @@ func (cp *CommandPlugin) processResourceUpdate(ctx context.Context, msg *bus.Mes } func (cp *CommandPlugin) createConnection(ctx context.Context, resource *mpi.Resource) { + var subscribeCtx context.Context + createConnectionResponse, err := cp.commandService.CreateConnection(ctx, resource) if err != nil { slog.ErrorContext(ctx, "Unable to create connection", "error", err) } + if createConnectionResponse != nil { + cp.subscribeMutex.Lock() + subscribeCtx, cp.subscribeCancel = context.WithCancel(ctx) + cp.subscribeMutex.Unlock() + + go cp.commandService.Subscribe(subscribeCtx) + cp.messagePipe.Process(ctx, &bus.Message{ Topic: bus.ConnectionCreatedTopic, Data: createConnectionResponse, diff --git a/internal/command/command_plugin_test.go b/internal/command/command_plugin_test.go index c98713f9f..90f56c394 100644 --- a/internal/command/command_plugin_test.go +++ b/internal/command/command_plugin_test.go @@ -70,7 +70,39 @@ func TestCommandPlugin_Init(t *testing.T) { closeError := commandPlugin.Close(ctx) require.NoError(t, closeError) - require.Equal(t, 1, fakeCommandService.CancelSubscriptionCallCount()) +} + +func TestCommandPlugin_createConnection(t *testing.T) { + ctx := context.Background() + commandService := &commandfakes.FakeCommandService{} + commandService.CreateConnectionReturns(&mpi.CreateConnectionResponse{}, nil) + messagePipe := busfakes.NewFakeMessagePipe() + + commandPlugin := NewCommandPlugin(types.AgentConfig(), &grpcfakes.FakeGrpcConnectionInterface{}) + err := commandPlugin.Init(ctx, messagePipe) + commandPlugin.commandService = commandService + require.NoError(t, err) + defer commandPlugin.Close(ctx) + + commandPlugin.createConnection(ctx, &mpi.Resource{}) + + assert.Eventually( + t, + func() bool { return commandService.SubscribeCallCount() > 0 }, + 2*time.Second, + 10*time.Millisecond, + ) + + assert.Eventually( + t, + func() bool { return len(messagePipe.GetMessages()) == 1 }, + 2*time.Second, + 10*time.Millisecond, + ) + + messages := messagePipe.GetMessages() + assert.Len(t, messages, 1) + assert.Equal(t, bus.ConnectionCreatedTopic, messages[0].Topic) } func TestCommandPlugin_Process(t *testing.T) { @@ -307,12 +339,12 @@ func TestCommandPlugin_FeatureDisabled(t *testing.T) { func TestMonitorSubscribeChannel(t *testing.T) { ctx, cncl := context.WithCancel(context.Background()) - defer cncl() logBuf := &bytes.Buffer{} stub.StubLoggerWith(logBuf) cp := NewCommandPlugin(types.AgentConfig(), &grpcfakes.FakeGrpcConnectionInterface{}) + cp.subscribeCancel = cncl message := protos.CreateManagementPlaneRequest() @@ -327,7 +359,7 @@ func TestMonitorSubscribeChannel(t *testing.T) { // Give some time to process the message time.Sleep(100 * time.Millisecond) - cncl() + cp.Close(ctx) time.Sleep(100 * time.Millisecond) diff --git a/internal/command/command_service.go b/internal/command/command_service.go index 7dd20c966..3c569e7f8 100644 --- a/internal/command/command_service.go +++ b/internal/command/command_service.go @@ -41,11 +41,9 @@ type ( subscribeClient mpi.CommandService_SubscribeClient agentConfig *config.Config isConnected *atomic.Bool - subscribeCancel context.CancelFunc subscribeChannel chan *mpi.ManagementPlaneRequest configApplyRequestQueue map[string][]*mpi.ManagementPlaneRequest // key is the instance ID resource *mpi.Resource - subscribeMutex sync.Mutex subscribeClientMutex sync.Mutex configApplyRequestQueueMutex sync.Mutex resourceMutex sync.Mutex @@ -53,7 +51,6 @@ type ( ) func NewCommandService( - ctx context.Context, commandServiceClient mpi.CommandServiceClient, agentConfig *config.Config, subscribeChannel chan *mpi.ManagementPlaneRequest, @@ -70,14 +67,6 @@ func NewCommandService( resource: &mpi.Resource{}, } - var subscribeCtx context.Context - - commandService.subscribeMutex.Lock() - subscribeCtx, commandService.subscribeCancel = context.WithCancel(ctx) - commandService.subscribeMutex.Unlock() - - go commandService.subscribe(subscribeCtx) - return commandService } @@ -190,17 +179,7 @@ func (cs *CommandService) SendDataPlaneResponse(ctx context.Context, response *m ) } -func (cs *CommandService) CancelSubscription(ctx context.Context) { - slog.InfoContext(ctx, "Canceling subscribe context") - - cs.subscribeMutex.Lock() - if cs.subscribeCancel != nil { - cs.subscribeCancel() - } - cs.subscribeMutex.Unlock() -} - -func (cs *CommandService) subscribe(ctx context.Context) { +func (cs *CommandService) Subscribe(ctx context.Context) { commonSettings := &config.BackOff{ InitialInterval: cs.agentConfig.Client.Backoff.InitialInterval, MaxInterval: cs.agentConfig.Client.Backoff.MaxInterval, diff --git a/internal/command/command_service_test.go b/internal/command/command_service_test.go index 899653f8f..0dfb00e95 100644 --- a/internal/command/command_service_test.go +++ b/internal/command/command_service_test.go @@ -77,30 +77,10 @@ func (*FakeConfigApplySubscribeClient) Recv() (*mpi.ManagementPlaneRequest, erro }, nil } -func TestCommandService_NewCommandService(t *testing.T) { - ctx := context.Background() - commandServiceClient := &v1fakes.FakeCommandServiceClient{} - - commandService := NewCommandService( - ctx, - commandServiceClient, - types.AgentConfig(), - make(chan *mpi.ManagementPlaneRequest), - ) - - defer commandService.CancelSubscription(ctx) - - assert.Eventually( - t, - func() bool { return commandServiceClient.SubscribeCallCount() > 0 }, - 2*time.Second, - 10*time.Millisecond, - ) -} - func TestCommandService_receiveCallback_configApplyRequest(t *testing.T) { - ctx := context.Background() fakeSubscribeClient := &FakeConfigApplySubscribeClient{} + ctx := context.Background() + subscribeCtx, subscribeCancel := context.WithCancel(ctx) commandServiceClient := &v1fakes.FakeCommandServiceClient{} commandServiceClient.SubscribeReturns(fakeSubscribeClient, nil) @@ -108,19 +88,18 @@ func TestCommandService_receiveCallback_configApplyRequest(t *testing.T) { subscribeChannel := make(chan *mpi.ManagementPlaneRequest) commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), subscribeChannel, ) + go commandService.Subscribe(subscribeCtx) + defer subscribeCancel() nginxInstance := protos.GetNginxOssInstance([]string{}) commandService.resourceMutex.Lock() commandService.resource.Instances = append(commandService.resource.Instances, nginxInstance) commandService.resourceMutex.Unlock() - defer commandService.CancelSubscription(ctx) - var wg sync.WaitGroup wg.Add(1) @@ -152,13 +131,10 @@ func TestCommandService_UpdateDataPlaneStatus(t *testing.T) { commandServiceClient.SubscribeReturns(fakeSubscribeClient, nil) commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), make(chan *mpi.ManagementPlaneRequest), ) - defer commandService.CancelSubscription(ctx) - // Fail first time since there are no other instances besides the agent err := commandService.UpdateDataPlaneStatus(ctx, protos.GetHostResource()) require.Error(t, err) @@ -191,12 +167,10 @@ func TestCommandService_UpdateDataPlaneStatusSubscribeError(t *testing.T) { stub.StubLoggerWith(logBuf) commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), make(chan *mpi.ManagementPlaneRequest), ) - defer commandService.CancelSubscription(ctx) commandService.isConnected.Store(true) @@ -213,7 +187,6 @@ func TestCommandService_CreateConnection(t *testing.T) { commandServiceClient := &v1fakes.FakeCommandServiceClient{} commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), make(chan *mpi.ManagementPlaneRequest), @@ -230,7 +203,6 @@ func TestCommandService_UpdateDataPlaneHealth(t *testing.T) { commandServiceClient := &v1fakes.FakeCommandServiceClient{} commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), make(chan *mpi.ManagementPlaneRequest), @@ -261,7 +233,6 @@ func TestCommandService_SendDataPlaneResponse(t *testing.T) { subscribeClient := &FakeSubscribeClient{} commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), make(chan *mpi.ManagementPlaneRequest), @@ -283,14 +254,11 @@ func TestCommandService_SendDataPlaneResponse_configApplyRequest(t *testing.T) { subscribeChannel := make(chan *mpi.ManagementPlaneRequest) commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), subscribeChannel, ) - defer commandService.CancelSubscription(ctx) - request1 := &mpi.ManagementPlaneRequest{ MessageMeta: &mpi.MessageMeta{ MessageId: "1", @@ -402,7 +370,6 @@ func TestCommandService_isValidRequest(t *testing.T) { subscribeClient := &FakeSubscribeClient{} commandService := NewCommandService( - ctx, commandServiceClient, types.AgentConfig(), make(chan *mpi.ManagementPlaneRequest), diff --git a/internal/command/commandfakes/fake_command_service.go b/internal/command/commandfakes/fake_command_service.go index 2df848b5e..0748ce080 100644 --- a/internal/command/commandfakes/fake_command_service.go +++ b/internal/command/commandfakes/fake_command_service.go @@ -9,11 +9,6 @@ import ( ) type FakeCommandService struct { - CancelSubscriptionStub func(context.Context) - cancelSubscriptionMutex sync.RWMutex - cancelSubscriptionArgsForCall []struct { - arg1 context.Context - } CreateConnectionStub func(context.Context, *v1.Resource) (*v1.CreateConnectionResponse, error) createConnectionMutex sync.RWMutex createConnectionArgsForCall []struct { @@ -50,6 +45,11 @@ type FakeCommandService struct { sendDataPlaneResponseReturnsOnCall map[int]struct { result1 error } + SubscribeStub func(context.Context) + subscribeMutex sync.RWMutex + subscribeArgsForCall []struct { + arg1 context.Context + } UpdateDataPlaneHealthStub func(context.Context, []*v1.InstanceHealth) error updateDataPlaneHealthMutex sync.RWMutex updateDataPlaneHealthArgsForCall []struct { @@ -78,38 +78,6 @@ type FakeCommandService struct { invocationsMutex sync.RWMutex } -func (fake *FakeCommandService) CancelSubscription(arg1 context.Context) { - fake.cancelSubscriptionMutex.Lock() - fake.cancelSubscriptionArgsForCall = append(fake.cancelSubscriptionArgsForCall, struct { - arg1 context.Context - }{arg1}) - stub := fake.CancelSubscriptionStub - fake.recordInvocation("CancelSubscription", []interface{}{arg1}) - fake.cancelSubscriptionMutex.Unlock() - if stub != nil { - fake.CancelSubscriptionStub(arg1) - } -} - -func (fake *FakeCommandService) CancelSubscriptionCallCount() int { - fake.cancelSubscriptionMutex.RLock() - defer fake.cancelSubscriptionMutex.RUnlock() - return len(fake.cancelSubscriptionArgsForCall) -} - -func (fake *FakeCommandService) CancelSubscriptionCalls(stub func(context.Context)) { - fake.cancelSubscriptionMutex.Lock() - defer fake.cancelSubscriptionMutex.Unlock() - fake.CancelSubscriptionStub = stub -} - -func (fake *FakeCommandService) CancelSubscriptionArgsForCall(i int) context.Context { - fake.cancelSubscriptionMutex.RLock() - defer fake.cancelSubscriptionMutex.RUnlock() - argsForCall := fake.cancelSubscriptionArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakeCommandService) CreateConnection(arg1 context.Context, arg2 *v1.Resource) (*v1.CreateConnectionResponse, error) { fake.createConnectionMutex.Lock() ret, specificReturn := fake.createConnectionReturnsOnCall[len(fake.createConnectionArgsForCall)] @@ -290,6 +258,38 @@ func (fake *FakeCommandService) SendDataPlaneResponseReturnsOnCall(i int, result }{result1} } +func (fake *FakeCommandService) Subscribe(arg1 context.Context) { + fake.subscribeMutex.Lock() + fake.subscribeArgsForCall = append(fake.subscribeArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.SubscribeStub + fake.recordInvocation("Subscribe", []interface{}{arg1}) + fake.subscribeMutex.Unlock() + if stub != nil { + fake.SubscribeStub(arg1) + } +} + +func (fake *FakeCommandService) SubscribeCallCount() int { + fake.subscribeMutex.RLock() + defer fake.subscribeMutex.RUnlock() + return len(fake.subscribeArgsForCall) +} + +func (fake *FakeCommandService) SubscribeCalls(stub func(context.Context)) { + fake.subscribeMutex.Lock() + defer fake.subscribeMutex.Unlock() + fake.SubscribeStub = stub +} + +func (fake *FakeCommandService) SubscribeArgsForCall(i int) context.Context { + fake.subscribeMutex.RLock() + defer fake.subscribeMutex.RUnlock() + argsForCall := fake.subscribeArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeCommandService) UpdateDataPlaneHealth(arg1 context.Context, arg2 []*v1.InstanceHealth) error { var arg2Copy []*v1.InstanceHealth if arg2 != nil { @@ -422,14 +422,14 @@ func (fake *FakeCommandService) UpdateDataPlaneStatusReturnsOnCall(i int, result func (fake *FakeCommandService) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() - fake.cancelSubscriptionMutex.RLock() - defer fake.cancelSubscriptionMutex.RUnlock() fake.createConnectionMutex.RLock() defer fake.createConnectionMutex.RUnlock() fake.isConnectedMutex.RLock() defer fake.isConnectedMutex.RUnlock() fake.sendDataPlaneResponseMutex.RLock() defer fake.sendDataPlaneResponseMutex.RUnlock() + fake.subscribeMutex.RLock() + defer fake.subscribeMutex.RUnlock() fake.updateDataPlaneHealthMutex.RLock() defer fake.updateDataPlaneHealthMutex.RUnlock() fake.updateDataPlaneStatusMutex.RLock()